Compare commits
No commits in common. "e938018494a1a5b512eb2a40ae9c4d5155d81299" and "df1ba6fe53c1fd3c2b490a28df0e586a49a48df8" have entirely different histories.
@@ -49,7 +49,7 @@ The TRPL implementation in Fancy RL includes projections based on the Kullback-L
To run the test suite:

```bash
pytest test/
pytest test/test_ppo.py
```

## Status
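As a side note, the same test selection can also be driven from Python rather than the shell. A minimal sketch using pytest's public `pytest.main` entry point; the test path mirrors the README snippet above, everything else is illustrative:

```python
# Minimal sketch: run the PPO test module programmatically via pytest.
# Assumes it is executed from the repository root, where test/test_ppo.py lives.
import sys
import pytest

if __name__ == "__main__":
    # pytest.main returns an exit code; propagate it like the CLI would.
    sys.exit(pytest.main(["test/test_ppo.py", "-q"]))
```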
@@ -1,12 +1,9 @@
import torch
import gymnasium as gym
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
from torchrl.envs.libs.gym import GymWrapper
from torchrl.record import VideoRecorder
from abc import ABC
import pdb
from tensordict import TensorDict
from torchrl.envs import GymWrapper, TransformedEnv
from torchrl.envs import BatchSizeTransform

from fancy_rl.loggers import TerminalLogger
@@ -50,33 +47,18 @@ class Algo(ABC):
        self.eval_episodes = eval_episodes

    def make_env(self, eval=False):
        """Creates an environment and wraps it if necessary."""
        env_spec = self.env_spec_eval if eval else self.env_spec
        env = self._wrap_env(env_spec)
        env.reset()
        return env

    def _wrap_env(self, env_spec):
        if isinstance(env_spec, str):
            env = GymEnv(env_spec, device=self.device)
        elif isinstance(env_spec, gym.Env):
            env = GymWrapper(env_spec, device=self.device)
        elif isinstance(env_spec, GymEnv):
            env = env_spec
            env = gym.make(env_spec)
            env = GymWrapper(env).to(self.device)
        elif callable(env_spec):
            base_env = env_spec()
            return self._wrap_env(base_env)
            env = env_spec()
            if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
                raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
            env = GymWrapper(env).to(self.device)
        else:
            raise ValueError(
                f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
                f"got {type(env_spec)}"
            )

        if not env.batch_size:
            env = TransformedEnv(
                env,
                BatchSizeTransform(batch_size=torch.Size([1]))
            )

            raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
        return env

    def train_step(self, batch):
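Both variants of `_wrap_env` shown in this hunk dispatch on the type of `env_spec` (environment id string, environment instance, or factory callable) and normalise the result into a TorchRL environment. A standalone sketch of that dispatch idea, assuming TorchRL's `GymEnv`/`GymWrapper`; the function name and the exact set of branches are illustrative, not either commit's actual implementation:

```python
# Illustrative sketch only: dispatch an env_spec to a TorchRL environment.
import gymnasium as gym
from torchrl.envs.libs.gym import GymEnv, GymWrapper

def wrap_env(env_spec, device="cpu"):
    if isinstance(env_spec, str):
        # Environment id: let TorchRL build the Gymnasium env directly.
        return GymEnv(env_spec, device=device)
    if isinstance(env_spec, gym.Env):
        # Already-instantiated Gymnasium env: wrap it for TorchRL.
        return GymWrapper(env_spec, device=device)
    if callable(env_spec):
        # Factory: build the env, then wrap the result recursively.
        return wrap_env(env_spec(), device=device)
    raise ValueError(f"Unsupported env_spec type: {type(env_spec)}")

# Usage:
# env = wrap_env("CartPole-v1")
# env = wrap_env(lambda: gym.make("CartPole-v1"))
```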
@@ -96,7 +78,7 @@ class Algo(ABC):
    ):
        with torch.no_grad():
            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
            td = TensorDict({"observation": obs_tensor})
            td = TensorDict({"observation": obs_tensor}, batch_size=[1])

            action_td = self.prob_actor(td)
            action = action_td["action"]
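The `predict` change above comes down to packing a single observation into a `TensorDict` with an explicit `batch_size=[1]` before handing it to the probabilistic actor. A minimal self-contained sketch of that pattern; the identity `actor` is only a stand-in for the real policy module:

```python
# Minimal sketch: batch a single observation into a TensorDict for a policy module.
import numpy as np
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# Stand-in "policy": copies the observation to an "action" entry.
actor = TensorDictModule(torch.nn.Identity(), in_keys=["observation"], out_keys=["action"])

observation = np.zeros(4, dtype=np.float32)              # e.g. a CartPole observation
obs_tensor = torch.as_tensor(observation).unsqueeze(0)   # shape [1, 4]

with torch.no_grad():
    td = TensorDict({"observation": obs_tensor}, batch_size=[1])
    action = actor(td)["action"]                          # shape [1, 4] for the identity stand-in

print(action.shape)
```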
@@ -2,9 +2,9 @@ import torch
from torchrl.modules import ProbabilisticActor
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.data.tensor_specs import DiscreteTensorSpec
from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.utils import is_discrete_space

class PPO(OnPolicy):
    def __init__(
@@ -41,25 +41,17 @@ class PPO(OnPolicy):
        # Initialize environment to get observation and action space sizes
        self.env_spec = env_spec
        env = self.make_env()

        # Get spaces from specs for parallel env
        obs_space = env.observation_spec
        act_space = env.action_spec

        self.discrete = isinstance(act_space, DiscreteTensorSpec)
        obs_space = env.observation_space
        act_space = env.action_space

        self.discrete = is_discrete_space(act_space)

        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

        if self.discrete:
            distribution_class = torch.distributions.Categorical
            self.prob_actor = ProbabilisticActor(
                module=self.actor,
                distribution_class=distribution_class,
                return_log_prob=True,
                in_keys=["logits"],
                out_keys=["action"],
            )
            distribution_kwargs = {"logits": "action_logits"}
        else:
            if full_covariance:
                distribution_class = torch.distributions.MultivariateNormal
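The PPO constructor above selects the action distribution from the action space: `Categorical` over logits for discrete spaces, and a (possibly full-covariance) Gaussian parameterised by `loc`/`scale` or `loc`/`scale_tril` otherwise. A hedged, standalone sketch of that selection; the helper name `pick_distribution` is illustrative and not part of Fancy RL's API:

```python
# Illustrative sketch: map a Gymnasium action space to a torch distribution class
# and the parameter keys a probabilistic actor would read.
import gymnasium as gym
import torch

def pick_distribution(action_space, full_covariance=False):
    if isinstance(action_space, gym.spaces.Discrete):
        # Discrete: logits -> Categorical.
        return torch.distributions.Categorical, ["logits"]
    if full_covariance:
        # Continuous with full covariance: loc + scale_tril -> MultivariateNormal.
        return torch.distributions.MultivariateNormal, ["loc", "scale_tril"]
    # Continuous, diagonal covariance: loc + scale -> Normal.
    return torch.distributions.Normal, ["loc", "scale"]

print(pick_distribution(gym.spaces.Discrete(4)))
print(pick_distribution(gym.spaces.Box(-1.0, 1.0, shape=(3,)), full_covariance=True))
```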
@@ -1,7 +1,6 @@
import torch
from torch import nn
from typing import Dict, Any, Optional
from torchrl.data.tensor_specs import DiscreteTensorSpec
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.collectors import SyncDataCollector
@@ -11,6 +10,7 @@ from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection, BaseProjection
from fancy_rl.objectives import TRPLLoss
from fancy_rl.utils import is_discrete_space
from copy import deepcopy
from tensordict.nn import TensorDictModule
from tensordict import TensorDict
@@ -80,7 +80,7 @@ class TRPL(OnPolicy):
        obs_space = env.observation_space
        act_space = env.action_space

        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
        assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"

        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
@@ -1,17 +1,14 @@
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.modules import MLP
from torchrl.data.tensor_specs import DiscreteTensorSpec
from tensordict.nn.distributions import NormalParamExtractor
from fancy_rl.utils import is_discrete_space, get_space_shape
from tensordict import TensorDict

class Actor(TensorDictModule):
    def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
        self.discrete = isinstance(act_space, DiscreteTensorSpec)

        obs_space = obs_space["observation"]
        act_space_shape = act_space.shape[1:]
        obs_space_shape = obs_space.shape[1:]
        self.discrete = is_discrete_space(act_space)
        act_space_shape = get_space_shape(act_space)

        if self.discrete and full_covariance:
            raise ValueError("Full covariance is not applicable for discrete action spaces.")
@@ -19,18 +16,18 @@ class Actor(TensorDictModule):
        self.full_covariance = full_covariance

        if self.discrete:
            out_features = act_space_shape[0]
            out_keys = ["logits"]
            out_features = act_space_shape[-1]
            out_keys = ["action_logits"]
        else:
            if full_covariance:
                out_features = act_space_shape[0] + (act_space_shape[0] * (act_space_shape[0] + 1)) // 2
                out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2
                out_keys = ["loc", "scale_tril"]
            else:
                out_features = act_space_shape[0] * 2
                out_features = act_space_shape[-1] * 2
                out_keys = ["loc", "scale"]

        actor_module = MLP(
            in_features=obs_space_shape[0],
            in_features=get_space_shape(obs_space)[-1],
            out_features=out_features,
            num_cells=hidden_sizes,
            activation_class=getattr(nn, activation_fn),
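The `out_features` arithmetic above is the usual parameter count for a Gaussian policy head: with action dimension d, a diagonal Gaussian needs 2d outputs (loc and scale), while a full-covariance Gaussian needs d for the mean plus d(d+1)/2 entries of a lower-triangular scale factor. A small worked check, pure arithmetic with no Fancy RL code:

```python
# Worked example of the policy-head output sizes used above.
def gaussian_head_size(action_dim: int, full_covariance: bool) -> int:
    if full_covariance:
        # mean (d) + lower-triangular Cholesky factor entries (d*(d+1)/2)
        return action_dim + (action_dim * (action_dim + 1)) // 2
    # mean (d) + per-dimension scale (d)
    return action_dim * 2

assert gaussian_head_size(3, full_covariance=False) == 6   # loc(3) + scale(3)
assert gaussian_head_size(3, full_covariance=True) == 9    # loc(3) + scale_tril(6)
```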
@@ -39,7 +36,7 @@ class Actor(TensorDictModule):

        if not self.discrete:
            if full_covariance:
                param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[0])
                param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1])
            else:
                param_extractor = NormalParamExtractor()
            actor_module = nn.Sequential(actor_module, param_extractor)
@@ -66,7 +63,7 @@ class FullCovarianceNormalParamExtractor(nn.Module):
class Critic(TensorDictModule):
    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
        critic_module = MLP(
            in_features=obs_space.shape[-1],
            in_features=get_space_shape(obs_space)[-1],
            out_features=1,
            num_cells=hidden_sizes,
            activation_class=getattr(nn, activation_fn),
@@ -0,0 +1,61 @@
import gymnasium
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
from torchrl.data.tensor_specs import (
    DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec,
    BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
)

try:
    import gym
    from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
    gym_available = True
except ImportError:
    gym_available = False

def is_discrete_space(action_space):
    discrete_types = (
        GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
        DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
    )
    continuous_types = (
        GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
    )

    if gym_available:
        discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
        continuous_types += (GymBox,)

    if isinstance(action_space, discrete_types):
        return True
    elif isinstance(action_space, continuous_types):
        return False
    else:
        raise ValueError(f"Unsupported action space type: {type(action_space)}")

def get_space_shape(action_space):
    discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
                      DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
    continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)

    if gym_available:
        discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
        continuous_types += (GymBox,)

    if isinstance(action_space, discrete_types):
        if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
            return (action_space.n,)
        elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
            return (sum(action_space.nvec),)
        elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
            return (action_space.n,)
        elif gym_available:
            if isinstance(action_space, GymDiscrete):
                return (action_space.n,)
            elif isinstance(action_space, GymMultiDiscrete):
                return (sum(action_space.nvec),)
            elif isinstance(action_space, GymMultiBinary):
                return (action_space.n,)
    elif isinstance(action_space, continuous_types):
        return action_space.shape

    raise ValueError(f"Unsupported action space type: {type(action_space)}")
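The helpers above accept Gymnasium spaces, legacy Gym spaces, and TorchRL tensor specs alike. A short usage sketch against Gymnasium spaces only, assuming the module is importable as `fancy_rl.utils` as the other files in this diff do:

```python
# Usage sketch for the helpers defined in the new utils module.
from gymnasium.spaces import Discrete, Box
from fancy_rl.utils import is_discrete_space, get_space_shape

disc = Discrete(4)
cont = Box(low=-1.0, high=1.0, shape=(3,))

print(is_discrete_space(disc))   # True
print(is_discrete_space(cont))   # False
print(get_space_shape(disc))     # (4,) -- number of discrete choices
print(get_space_shape(cont))     # (3,) -- the Box shape
```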
@@ -2,10 +2,9 @@ import pytest
import numpy as np
from fancy_rl import PPO
import gymnasium as gym
from torchrl.envs import GymEnv

def simple_env():
    return GymEnv('LunarLander-v2', continuous=True)
    return gym.make('LunarLander-v2', continuous=True)

def test_ppo_instantiation():
    ppo = PPO(simple_env)
@@ -15,10 +14,6 @@ def test_ppo_instantiation_from_str():
    ppo = PPO('CartPole-v1')
    assert isinstance(ppo, PPO)

def test_ppo_instantiation_from_make():
    ppo = PPO(gym.make('CartPole-v1'))
    assert isinstance(ppo, PPO)

@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@@ -53,12 +48,12 @@ def test_ppo_predict():
def test_ppo_learn():
    ppo = PPO(simple_env, n_steps=64, batch_size=32)
    env = ppo.make_env()
    obs = env.reset()
    obs, _ = env.reset()
    for _ in range(64):
        action, _next_state = ppo.predict(obs)
        action, _ = ppo.predict(obs)
        obs, reward, done, truncated, _ = env.step(action)
        if done or truncated:
            obs = env.reset()
            obs, _ = env.reset()

def test_ppo_training():
    ppo = PPO(simple_env, total_timesteps=10000)
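The test changes above track the Gymnasium API, where `reset()` returns an `(observation, info)` tuple and `step()` returns five values. A minimal interaction loop with a random policy showing the same pattern; it is standalone and does not involve Fancy RL:

```python
# Minimal Gymnasium interaction loop illustrating the reset()/step() signatures
# the updated tests rely on.
import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset()
for _ in range(64):
    action = env.action_space.sample()          # random stand-in for ppo.predict(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()
```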
@@ -73,10 +68,10 @@ def test_ppo_training():
def evaluate_policy(policy, env, n_eval_episodes=10):
    total_reward = 0
    for _ in range(n_eval_episodes):
        obs = env.reset()
        obs, _ = env.reset()
        done = False
        while not done:
            action, _next_state = policy.predict(obs)
            action, _ = policy.predict(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated