Compare commits
5 Commits
df1ba6fe53
...
e938018494
Author | SHA1 | Date | |
---|---|---|---|
e938018494 | |||
4f8fc500b7 | |||
5c44448e53 | |||
8a078fb59e | |||
52b3f3b71e |
@ -49,7 +49,7 @@ The TRPL implementation in Fancy RL includes projections based on the Kullback-L
|
|||||||
To run the test suite:
|
To run the test suite:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pytest test/test_ppo.py
|
pytest test/
|
||||||
```
|
```
|
||||||
|
|
||||||
## Status
|
## Status
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from torchrl.envs.libs.gym import GymWrapper
|
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
|
||||||
from torchrl.record import VideoRecorder
|
from torchrl.record import VideoRecorder
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
import pdb
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
|
from torchrl.envs import GymWrapper, TransformedEnv
|
||||||
|
from torchrl.envs import BatchSizeTransform
|
||||||
|
|
||||||
from fancy_rl.loggers import TerminalLogger
|
from fancy_rl.loggers import TerminalLogger
|
||||||
|
|
||||||
@ -47,18 +50,33 @@ class Algo(ABC):
|
|||||||
self.eval_episodes = eval_episodes
|
self.eval_episodes = eval_episodes
|
||||||
|
|
||||||
def make_env(self, eval=False):
|
def make_env(self, eval=False):
|
||||||
"""Creates an environment and wraps it if necessary."""
|
|
||||||
env_spec = self.env_spec_eval if eval else self.env_spec
|
env_spec = self.env_spec_eval if eval else self.env_spec
|
||||||
|
env = self._wrap_env(env_spec)
|
||||||
|
env.reset()
|
||||||
|
return env
|
||||||
|
|
||||||
|
def _wrap_env(self, env_spec):
|
||||||
if isinstance(env_spec, str):
|
if isinstance(env_spec, str):
|
||||||
env = gym.make(env_spec)
|
env = GymEnv(env_spec, device=self.device)
|
||||||
env = GymWrapper(env).to(self.device)
|
elif isinstance(env_spec, gym.Env):
|
||||||
|
env = GymWrapper(env_spec, device=self.device)
|
||||||
|
elif isinstance(env_spec, GymEnv):
|
||||||
|
env = env_spec
|
||||||
elif callable(env_spec):
|
elif callable(env_spec):
|
||||||
env = env_spec()
|
base_env = env_spec()
|
||||||
if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
|
return self._wrap_env(base_env)
|
||||||
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
|
|
||||||
env = GymWrapper(env).to(self.device)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
|
raise ValueError(
|
||||||
|
f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
|
||||||
|
f"got {type(env_spec)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not env.batch_size:
|
||||||
|
env = TransformedEnv(
|
||||||
|
env,
|
||||||
|
BatchSizeTransform(batch_size=torch.Size([1]))
|
||||||
|
)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def train_step(self, batch):
|
def train_step(self, batch):
|
||||||
@ -78,7 +96,7 @@ class Algo(ABC):
|
|||||||
):
|
):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
|
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
|
||||||
td = TensorDict({"observation": obs_tensor}, batch_size=[1])
|
td = TensorDict({"observation": obs_tensor})
|
||||||
|
|
||||||
action_td = self.prob_actor(td)
|
action_td = self.prob_actor(td)
|
||||||
action = action_td["action"]
|
action = action_td["action"]
|
||||||
|
@ -2,9 +2,9 @@ import torch
|
|||||||
from torchrl.modules import ProbabilisticActor
|
from torchrl.modules import ProbabilisticActor
|
||||||
from torchrl.objectives import ClipPPOLoss
|
from torchrl.objectives import ClipPPOLoss
|
||||||
from torchrl.objectives.value.advantages import GAE
|
from torchrl.objectives.value.advantages import GAE
|
||||||
|
from torchrl.data.tensor_specs import DiscreteTensorSpec
|
||||||
from fancy_rl.algos.on_policy import OnPolicy
|
from fancy_rl.algos.on_policy import OnPolicy
|
||||||
from fancy_rl.policy import Actor, Critic
|
from fancy_rl.policy import Actor, Critic
|
||||||
from fancy_rl.utils import is_discrete_space
|
|
||||||
|
|
||||||
class PPO(OnPolicy):
|
class PPO(OnPolicy):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -41,17 +41,25 @@ class PPO(OnPolicy):
|
|||||||
# Initialize environment to get observation and action space sizes
|
# Initialize environment to get observation and action space sizes
|
||||||
self.env_spec = env_spec
|
self.env_spec = env_spec
|
||||||
env = self.make_env()
|
env = self.make_env()
|
||||||
obs_space = env.observation_space
|
|
||||||
act_space = env.action_space
|
|
||||||
|
|
||||||
self.discrete = is_discrete_space(act_space)
|
# Get spaces from specs for parallel env
|
||||||
|
obs_space = env.observation_spec
|
||||||
|
act_space = env.action_spec
|
||||||
|
|
||||||
|
self.discrete = isinstance(act_space, DiscreteTensorSpec)
|
||||||
|
|
||||||
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
|
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
|
||||||
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
|
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
|
||||||
|
|
||||||
if self.discrete:
|
if self.discrete:
|
||||||
distribution_class = torch.distributions.Categorical
|
distribution_class = torch.distributions.Categorical
|
||||||
distribution_kwargs = {"logits": "action_logits"}
|
self.prob_actor = ProbabilisticActor(
|
||||||
|
module=self.actor,
|
||||||
|
distribution_class=distribution_class,
|
||||||
|
return_log_prob=True,
|
||||||
|
in_keys=["logits"],
|
||||||
|
out_keys=["action"],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if full_covariance:
|
if full_covariance:
|
||||||
distribution_class = torch.distributions.MultivariateNormal
|
distribution_class = torch.distributions.MultivariateNormal
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
from torchrl.data.tensor_specs import DiscreteTensorSpec
|
||||||
from torchrl.modules import ProbabilisticActor, ValueOperator
|
from torchrl.modules import ProbabilisticActor, ValueOperator
|
||||||
from torchrl.objectives import ClipPPOLoss
|
from torchrl.objectives import ClipPPOLoss
|
||||||
from torchrl.collectors import SyncDataCollector
|
from torchrl.collectors import SyncDataCollector
|
||||||
@ -10,7 +11,6 @@ from fancy_rl.algos.on_policy import OnPolicy
|
|||||||
from fancy_rl.policy import Actor, Critic
|
from fancy_rl.policy import Actor, Critic
|
||||||
from fancy_rl.projections import get_projection, BaseProjection
|
from fancy_rl.projections import get_projection, BaseProjection
|
||||||
from fancy_rl.objectives import TRPLLoss
|
from fancy_rl.objectives import TRPLLoss
|
||||||
from fancy_rl.utils import is_discrete_space
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from tensordict.nn import TensorDictModule
|
from tensordict.nn import TensorDictModule
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
@ -80,7 +80,7 @@ class TRPL(OnPolicy):
|
|||||||
obs_space = env.observation_space
|
obs_space = env.observation_space
|
||||||
act_space = env.action_space
|
act_space = env.action_space
|
||||||
|
|
||||||
assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
|
assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
|
||||||
|
|
||||||
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
|
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
|
||||||
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
|
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
|
||||||
|
@ -1,14 +1,17 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from tensordict.nn import TensorDictModule
|
from tensordict.nn import TensorDictModule
|
||||||
from torchrl.modules import MLP
|
from torchrl.modules import MLP
|
||||||
|
from torchrl.data.tensor_specs import DiscreteTensorSpec
|
||||||
from tensordict.nn.distributions import NormalParamExtractor
|
from tensordict.nn.distributions import NormalParamExtractor
|
||||||
from fancy_rl.utils import is_discrete_space, get_space_shape
|
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
|
|
||||||
class Actor(TensorDictModule):
|
class Actor(TensorDictModule):
|
||||||
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
|
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
|
||||||
self.discrete = is_discrete_space(act_space)
|
self.discrete = isinstance(act_space, DiscreteTensorSpec)
|
||||||
act_space_shape = get_space_shape(act_space)
|
|
||||||
|
obs_space = obs_space["observation"]
|
||||||
|
act_space_shape = act_space.shape[1:]
|
||||||
|
obs_space_shape = obs_space.shape[1:]
|
||||||
|
|
||||||
if self.discrete and full_covariance:
|
if self.discrete and full_covariance:
|
||||||
raise ValueError("Full covariance is not applicable for discrete action spaces.")
|
raise ValueError("Full covariance is not applicable for discrete action spaces.")
|
||||||
@ -16,18 +19,18 @@ class Actor(TensorDictModule):
|
|||||||
self.full_covariance = full_covariance
|
self.full_covariance = full_covariance
|
||||||
|
|
||||||
if self.discrete:
|
if self.discrete:
|
||||||
out_features = act_space_shape[-1]
|
out_features = act_space_shape[0]
|
||||||
out_keys = ["action_logits"]
|
out_keys = ["logits"]
|
||||||
else:
|
else:
|
||||||
if full_covariance:
|
if full_covariance:
|
||||||
out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2
|
out_features = act_space_shape[0] + (act_space_shape[0] * (act_space_shape[0] + 1)) // 2
|
||||||
out_keys = ["loc", "scale_tril"]
|
out_keys = ["loc", "scale_tril"]
|
||||||
else:
|
else:
|
||||||
out_features = act_space_shape[-1] * 2
|
out_features = act_space_shape[0] * 2
|
||||||
out_keys = ["loc", "scale"]
|
out_keys = ["loc", "scale"]
|
||||||
|
|
||||||
actor_module = MLP(
|
actor_module = MLP(
|
||||||
in_features=get_space_shape(obs_space)[-1],
|
in_features=obs_space_shape[0],
|
||||||
out_features=out_features,
|
out_features=out_features,
|
||||||
num_cells=hidden_sizes,
|
num_cells=hidden_sizes,
|
||||||
activation_class=getattr(nn, activation_fn),
|
activation_class=getattr(nn, activation_fn),
|
||||||
@ -36,7 +39,7 @@ class Actor(TensorDictModule):
|
|||||||
|
|
||||||
if not self.discrete:
|
if not self.discrete:
|
||||||
if full_covariance:
|
if full_covariance:
|
||||||
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1])
|
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[0])
|
||||||
else:
|
else:
|
||||||
param_extractor = NormalParamExtractor()
|
param_extractor = NormalParamExtractor()
|
||||||
actor_module = nn.Sequential(actor_module, param_extractor)
|
actor_module = nn.Sequential(actor_module, param_extractor)
|
||||||
@ -63,7 +66,7 @@ class FullCovarianceNormalParamExtractor(nn.Module):
|
|||||||
class Critic(TensorDictModule):
|
class Critic(TensorDictModule):
|
||||||
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
|
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
|
||||||
critic_module = MLP(
|
critic_module = MLP(
|
||||||
in_features=get_space_shape(obs_space)[-1],
|
in_features=obs_space.shape[-1],
|
||||||
out_features=1,
|
out_features=1,
|
||||||
num_cells=hidden_sizes,
|
num_cells=hidden_sizes,
|
||||||
activation_class=getattr(nn, activation_fn),
|
activation_class=getattr(nn, activation_fn),
|
||||||
|
@ -1,61 +0,0 @@
|
|||||||
import gymnasium
|
|
||||||
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
|
|
||||||
from torchrl.data.tensor_specs import (
|
|
||||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec,
|
|
||||||
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import gym
|
|
||||||
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
|
|
||||||
gym_available = True
|
|
||||||
except ImportError:
|
|
||||||
gym_available = False
|
|
||||||
|
|
||||||
def is_discrete_space(action_space):
|
|
||||||
discrete_types = (
|
|
||||||
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
|
|
||||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
|
|
||||||
)
|
|
||||||
continuous_types = (
|
|
||||||
GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
|
||||||
)
|
|
||||||
|
|
||||||
if gym_available:
|
|
||||||
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
|
||||||
continuous_types += (GymBox,)
|
|
||||||
|
|
||||||
if isinstance(action_space, discrete_types):
|
|
||||||
return True
|
|
||||||
elif isinstance(action_space, continuous_types):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported action space type: {type(action_space)}")
|
|
||||||
|
|
||||||
def get_space_shape(action_space):
|
|
||||||
discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
|
|
||||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
|
|
||||||
continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)
|
|
||||||
|
|
||||||
if gym_available:
|
|
||||||
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
|
||||||
continuous_types += (GymBox,)
|
|
||||||
|
|
||||||
if isinstance(action_space, discrete_types):
|
|
||||||
if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
|
|
||||||
return (action_space.n,)
|
|
||||||
elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
|
|
||||||
return (sum(action_space.nvec),)
|
|
||||||
elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
|
|
||||||
return (action_space.n,)
|
|
||||||
elif gym_available:
|
|
||||||
if isinstance(action_space, GymDiscrete):
|
|
||||||
return (action_space.n,)
|
|
||||||
elif isinstance(action_space, GymMultiDiscrete):
|
|
||||||
return (sum(action_space.nvec),)
|
|
||||||
elif isinstance(action_space, GymMultiBinary):
|
|
||||||
return (action_space.n,)
|
|
||||||
elif isinstance(action_space, continuous_types):
|
|
||||||
return action_space.shape
|
|
||||||
|
|
||||||
raise ValueError(f"Unsupported action space type: {type(action_space)}")
|
|
@ -2,9 +2,10 @@ import pytest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from fancy_rl import PPO
|
from fancy_rl import PPO
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
from torchrl.envs import GymEnv
|
||||||
|
|
||||||
def simple_env():
|
def simple_env():
|
||||||
return gym.make('LunarLander-v2', continuous=True)
|
return GymEnv('LunarLander-v2', continuous=True)
|
||||||
|
|
||||||
def test_ppo_instantiation():
|
def test_ppo_instantiation():
|
||||||
ppo = PPO(simple_env)
|
ppo = PPO(simple_env)
|
||||||
@ -14,6 +15,10 @@ def test_ppo_instantiation_from_str():
|
|||||||
ppo = PPO('CartPole-v1')
|
ppo = PPO('CartPole-v1')
|
||||||
assert isinstance(ppo, PPO)
|
assert isinstance(ppo, PPO)
|
||||||
|
|
||||||
|
def test_ppo_instantiation_from_make():
|
||||||
|
ppo = PPO(gym.make('CartPole-v1'))
|
||||||
|
assert isinstance(ppo, PPO)
|
||||||
|
|
||||||
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
|
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
|
||||||
@pytest.mark.parametrize("n_steps", [1024, 2048])
|
@pytest.mark.parametrize("n_steps", [1024, 2048])
|
||||||
@pytest.mark.parametrize("batch_size", [32, 64, 128])
|
@pytest.mark.parametrize("batch_size", [32, 64, 128])
|
||||||
@ -48,12 +53,12 @@ def test_ppo_predict():
|
|||||||
def test_ppo_learn():
|
def test_ppo_learn():
|
||||||
ppo = PPO(simple_env, n_steps=64, batch_size=32)
|
ppo = PPO(simple_env, n_steps=64, batch_size=32)
|
||||||
env = ppo.make_env()
|
env = ppo.make_env()
|
||||||
obs, _ = env.reset()
|
obs = env.reset()
|
||||||
for _ in range(64):
|
for _ in range(64):
|
||||||
action, _ = ppo.predict(obs)
|
action, _next_state = ppo.predict(obs)
|
||||||
obs, reward, done, truncated, _ = env.step(action)
|
obs, reward, done, truncated, _ = env.step(action)
|
||||||
if done or truncated:
|
if done or truncated:
|
||||||
obs, _ = env.reset()
|
obs = env.reset()
|
||||||
|
|
||||||
def test_ppo_training():
|
def test_ppo_training():
|
||||||
ppo = PPO(simple_env, total_timesteps=10000)
|
ppo = PPO(simple_env, total_timesteps=10000)
|
||||||
@ -68,10 +73,10 @@ def test_ppo_training():
|
|||||||
def evaluate_policy(policy, env, n_eval_episodes=10):
|
def evaluate_policy(policy, env, n_eval_episodes=10):
|
||||||
total_reward = 0
|
total_reward = 0
|
||||||
for _ in range(n_eval_episodes):
|
for _ in range(n_eval_episodes):
|
||||||
obs, _ = env.reset()
|
obs = env.reset()
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
action, _ = policy.predict(obs)
|
action, _next_state = policy.predict(obs)
|
||||||
obs, reward, terminated, truncated, _ = env.step(action)
|
obs, reward, terminated, truncated, _ = env.step(action)
|
||||||
total_reward += reward
|
total_reward += reward
|
||||||
done = terminated or truncated
|
done = terminated or truncated
|
||||||
|
Loading…
Reference in New Issue
Block a user