Compare commits

3816adef9ae9c33c968d7ec4908ab5879559edfb..e938018494a1a5b512eb2a40ae9c4d5155d81299

No commits in common. "3816adef9ae9c33c968d7ec4908ab5879559edfb" and "e938018494a1a5b512eb2a40ae9c4d5155d81299" have entirely different histories.

9 changed files with 121 additions and 147 deletions

.gitignore vendored

@@ -1,6 +1,5 @@
 __pycache__
 .venv
-.vscode
 wandb
 *.egg-info/
 test.py


@@ -1,10 +1,9 @@
 import torch
 import gymnasium as gym
-from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, SerialEnv
+from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
 from torchrl.record import VideoRecorder
 from abc import ABC
 import pdb
-import numpy as np
 from tensordict import TensorDict
 from torchrl.envs import GymWrapper, TransformedEnv
 from torchrl.envs import BatchSizeTransform
@@ -57,31 +56,27 @@ class Algo(ABC):
         return env

     def _wrap_env(self, env_spec):
-        # If given an existing env, ensure it's properly batched
-        if isinstance(env_spec, (GymEnv, GymWrapper)):
-            if not env_spec.batch_size:
-                raise ValueError("Environment must be batched")
-            return env_spec
-        # Handle callable without wrapping the recursive call
-        if callable(env_spec):
-            return self._wrap_env(env_spec())
-        # Create new batched environment using SerialEnv
         if isinstance(env_spec, str):
-            env = SerialEnv(1, lambda: GymEnv(env_spec, device=self.device))
+            env = GymEnv(env_spec, device=self.device)
         elif isinstance(env_spec, gym.Env):
-            wrapped_env = GymWrapper(env_spec, device=self.device)
-            if wrapped_env.batch_size:
-                env = wrapped_env
-            else:
-                env = SerialEnv(1, lambda: wrapped_env)
+            env = GymWrapper(env_spec, device=self.device)
+        elif isinstance(env_spec, GymEnv):
+            env = env_spec
+        elif callable(env_spec):
+            base_env = env_spec()
+            return self._wrap_env(base_env)
         else:
             raise ValueError(
                 f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
                 f"got {type(env_spec)}"
             )
+        if not env.batch_size:
+            env = TransformedEnv(
+                env,
+                BatchSizeTransform(batch_size=torch.Size([1]))
+            )
         return env

     def train_step(self, batch):
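In the rewritten branch above, an unbatched environment is no longer wrapped in SerialEnv(1, ...) but given a leading batch dimension via BatchSizeTransform. A minimal standalone sketch of that fallback, using the same torchrl imports this file already relies on (the environment name is only an example):

    import torch
    from torchrl.envs import GymEnv, TransformedEnv, BatchSizeTransform

    # A plain GymEnv reports an empty batch_size (torch.Size([])).
    env = GymEnv("Pendulum-v1")
    if not env.batch_size:
        # Add a leading batch dimension of 1 without spawning a SerialEnv worker.
        env = TransformedEnv(env, BatchSizeTransform(batch_size=torch.Size([1])))
    print(env.batch_size)  # expected: torch.Size([1])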
@@ -95,21 +90,18 @@ class Algo(ABC):
     def predict(
         self,
-        tensordict,
+        observation,
         state=None,
         deterministic=False
     ):
         with torch.no_grad():
-            # If numpy array, convert to TensorDict
-            if isinstance(tensordict, np.ndarray):
-                tensordict = TensorDict(
-                    {"observation": torch.from_numpy(tensordict).float()},
-                    batch_size=[]
-                )
-            # Move to device
-            tensordict = tensordict.to(self.device)
-            # Get action from policy
-            action_td = self.prob_actor(tensordict)
-            return action_td
+            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
+            td = TensorDict({"observation": obs_tensor})
+            action_td = self.prob_actor(td)
+            action = action_td["action"]
+            # We're not using recurrent policies, so we'll always return None for the state
+            next_state = None
+            return action.squeeze(0).cpu().numpy(), next_state
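This hunk moves predict() from a TensorDict-in/TensorDict-out interface to a raw observation in and an (action, next_state) tuple out, which is what the rewritten tests further down rely on. A hedged usage sketch of the new interface, assuming a PPO agent from this package and a plain Gymnasium env; none of this code is part of the diff itself:

    import gymnasium as gym
    from fancy_rl import PPO

    ppo = PPO("Pendulum-v1")
    env = gym.make("Pendulum-v1")

    obs, _ = env.reset()
    for _ in range(10):
        # Raw NumPy observation in, (NumPy action, recurrent state) out; state stays None here.
        action, _state = ppo.predict(obs)
        obs, reward, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            obs, _ = env.reset()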


@@ -43,13 +43,13 @@ class PPO(OnPolicy):
         env = self.make_env()

         # Get spaces from specs for parallel env
-        self.obs_space = env.observation_spec
-        self.act_space = env.action_spec
-        self.discrete = isinstance(self.act_space, DiscreteTensorSpec)
-        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        obs_space = env.observation_spec
+        act_space = env.action_spec
+        self.discrete = isinstance(act_space, DiscreteTensorSpec)
+        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

         if self.discrete:
             distribution_class = torch.distributions.Categorical
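The PPO hunk above only stops keeping the specs on self; the spaces are still read from torchrl's spec API on both sides of the diff. For readers coming from Gym, a small illustrative sketch of what those spec objects look like (not part of the change; CartPole is just an example):

    from torchrl.envs import GymEnv

    env = GymEnv("CartPole-v1")
    # Structured spec objects instead of gym.spaces; shapes live on .shape.
    print(env.observation_spec)
    print(env.action_spec)
    print(env.action_spec.shape)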


@@ -14,8 +14,6 @@ from fancy_rl.objectives import TRPLLoss
 from copy import deepcopy
 from tensordict.nn import TensorDictModule
 from tensordict import TensorDict
-from torch.distributions import Categorical, MultivariateNormal, Normal

 class ProjectedActor(TensorDictModule):
     def __init__(self, raw_actor, old_actor, projection):
@@ -28,8 +26,6 @@ class ProjectedActor(TensorDictModule):
         self.raw_actor = raw_actor
         self.old_actor = old_actor
         self.projection = projection
-        self.discrete = raw_actor.discrete
-        self.full_covariance = raw_actor.full_covariance

     class CombinedModule(nn.Module):
         def __init__(self, raw_actor, old_actor, projection):
@@ -39,41 +35,12 @@ class ProjectedActor(TensorDictModule):
             self.projection = projection

         def forward(self, tensordict):
-            # Convert the tuple outputs to TensorDict
             raw_params = self.raw_actor(tensordict)
-            if isinstance(raw_params, tuple):
-                raw_params = TensorDict({
-                    "loc": raw_params[0],
-                    "scale": raw_params[1]
-                }, batch_size=[raw_params[0].shape[0]])  # Use the first dimension of the tensor as batch size
             old_params = self.old_actor(tensordict)
-            if isinstance(old_params, tuple):
-                old_params = TensorDict({
-                    "loc": old_params[0],
-                    "scale": old_params[1]
-                }, batch_size=[old_params[0].shape[0]])  # Use the first dimension of the tensor as batch size
-            # Now combine them
-            combined_params = TensorDict({
-                **raw_params,
-                **{f"old_{key}": value for key, value in old_params.items()}
-            }, batch_size=[raw_params["loc"].shape[0]])  # Use the first dimension of loc tensor as batch size
+            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
             projected_params = self.projection(combined_params)
             return projected_params

-    def get_dist(self, tensordict):
-        # Forward the observation through the network
-        out = self.forward(tensordict)
-        if self.discrete:
-            return Categorical(logits=out["logits"])
-        else:
-            if self.full_covariance:
-                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
-            else:
-                return Normal(loc=out["loc"], scale=out["scale"])

 class TRPL(OnPolicy):
     def __init__(
         self,
@@ -110,16 +77,14 @@ class TRPL(OnPolicy):
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
         env = self.make_env()
-        # Get spaces from specs for parallel env
-        self.obs_space = env.observation_spec
-        self.act_space = env.action_spec
-        assert not isinstance(self.act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
-        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.raw_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
-        self.old_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        obs_space = env.observation_space
+        act_space = env.action_space
+        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
+        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

         # Handle projection_class
         if isinstance(projection_class, str):
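ProjectedActor.forward above merges the current and old policy parameters into a single TensorDict, prefixing the old policy's keys with "old_" before handing everything to the projection. A standalone sketch of that key-merging pattern with dummy loc/scale tensors (purely illustrative, not code from this repo):

    import torch
    from tensordict import TensorDict

    batch = torch.Size([4])
    raw_params = TensorDict({"loc": torch.zeros(4, 2), "scale": torch.ones(4, 2)}, batch_size=batch)
    old_params = TensorDict({"loc": torch.randn(4, 2), "scale": torch.ones(4, 2)}, batch_size=batch)

    # Same pattern as forward() above: keep the new params as-is and
    # re-key the old ones so both sets live side by side in one TensorDict.
    combined = TensorDict(
        {**raw_params, **{f"old_{k}": v for k, v in old_params.items()}},
        batch_size=batch,
    )
    print(sorted(combined.keys()))  # ['loc', 'old_loc', 'old_scale', 'scale']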


@@ -4,7 +4,6 @@ from torchrl.modules import MLP
 from torchrl.data.tensor_specs import DiscreteTensorSpec
 from tensordict.nn.distributions import NormalParamExtractor
 from tensordict import TensorDict
-from torch.distributions import Categorical, MultivariateNormal, Normal

 class Actor(TensorDictModule):
     def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
@@ -51,17 +50,6 @@ class Actor(TensorDictModule):
             out_keys=out_keys
         )

-    def get_dist(self, tensordict):
-        # Forward the observation through the network
-        out = self.forward(tensordict)
-        if self.discrete:
-            return Categorical(logits=out["logits"])
-        else:
-            if self.full_covariance:
-                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
-            else:
-                return Normal(loc=out["loc"], scale=out["scale"])

 class FullCovarianceNormalParamExtractor(nn.Module):
     def __init__(self, action_dim):
         super().__init__()
@@ -77,11 +65,8 @@ class FullCovarianceNormalParamExtractor(nn.Module):
 class Critic(TensorDictModule):
     def __init__(self, obs_space, hidden_sizes, activation_fn, device):
-        obs_space = obs_space["observation"]
-        obs_space_shape = obs_space.shape[1:]
         critic_module = MLP(
-            in_features=obs_space_shape[0],
+            in_features=obs_space.shape[-1],
             out_features=1,
             num_cells=hidden_sizes,
             activation_class=getattr(nn, activation_fn),
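With the hunk above, Critic sizes its network directly from obs_space.shape[-1] instead of first unwrapping an "observation" sub-spec. A small illustrative sketch of the underlying torchrl MLP construction, with placeholder sizes (the real code passes the spec-derived dimension and the configured hidden sizes):

    import torch
    import torch.nn as nn
    from torchrl.modules import MLP

    obs_dim = 8  # placeholder; the Critic uses obs_space.shape[-1]
    critic_module = MLP(
        in_features=obs_dim,
        out_features=1,          # scalar state value
        num_cells=[64, 64],      # hidden layer sizes
        activation_class=nn.Tanh,
    )
    print(critic_module(torch.randn(4, obs_dim)).shape)  # torch.Size([4, 1])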


@@ -1,9 +1,5 @@
 import torch
-try:
-    import cpp_projection
-    cpp_projection_available = True
-except ImportError:
-    cpp_projection_available = False
+import cpp_projection
 import numpy as np
 from .base_projection import BaseProjection
 from tensordict.nn import TensorDictModule


@@ -8,7 +8,7 @@
 description = "Minimalistic and efficient implementations of PPO and TRPL for torchrl"
 authors = [{name = "Dominik Roth", email = "mail@dominik-roth.eu"}]
 readme = "README.md"
-requires-python = ">=3.7"
+requires-python = ">=3.7,<3.12"
 classifiers = [
     "Development Status :: 3 - Alpha",
     "Intended Audience :: Developers",
@@ -23,10 +23,9 @@
 dependencies = [
     "numpy",
     "torch",
-    "gymnasium<1.0",
+    "gymnasium",
     "tensordict",
     "torchrl",
-    "pytest",
 ]

 [project.urls]
@@ -34,4 +33,3 @@
 [project.optional-dependencies]
 dev = ["pytest"]
-box2d = ["swig", "gymnasium[box2d]"]


@@ -3,11 +3,9 @@ import numpy as np
 from fancy_rl import PPO
 import gymnasium as gym
 from torchrl.envs import GymEnv
-import torch as th
-from tensordict import TensorDict

 def simple_env():
-    return gym.make('LunarLander-v2')
+    return GymEnv('LunarLander-v2', continuous=True)

 def test_ppo_instantiation():
     ppo = PPO(simple_env)
@@ -17,38 +15,69 @@ def test_ppo_instantiation_from_str():
     ppo = PPO('CartPole-v1')
     assert isinstance(ppo, PPO)

+def test_ppo_instantiation_from_make():
+    ppo = PPO(gym.make('CartPole-v1'))
+    assert isinstance(ppo, PPO)
+
+@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
+@pytest.mark.parametrize("n_steps", [1024, 2048])
+@pytest.mark.parametrize("batch_size", [32, 64, 128])
+@pytest.mark.parametrize("n_epochs", [5, 10])
+@pytest.mark.parametrize("gamma", [0.95, 0.99])
+@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
+def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
+    ppo = PPO(
+        simple_env,
+        learning_rate=learning_rate,
+        n_steps=n_steps,
+        batch_size=batch_size,
+        n_epochs=n_epochs,
+        gamma=gamma,
+        clip_range=clip_range
+    )
+    assert ppo.learning_rate == learning_rate
+    assert ppo.n_steps == n_steps
+    assert ppo.batch_size == batch_size
+    assert ppo.n_epochs == n_epochs
+    assert ppo.gamma == gamma
+    assert ppo.clip_range == clip_range
+
 def test_ppo_predict():
     ppo = PPO(simple_env)
     env = ppo.make_env()
+    obs, _ = env.reset()
+    action, _ = ppo.predict(obs)
+    assert isinstance(action, np.ndarray)
+    assert action.shape == env.action_space.shape
+
+def test_ppo_learn():
+    ppo = PPO(simple_env, n_steps=64, batch_size=32)
+    env = ppo.make_env()
     obs = env.reset()
-    action = ppo.predict(obs)
-    assert isinstance(action, TensorDict)
-    # Handle both single and composite action spaces
-    if isinstance(env.action_space, list):
-        expected_shape = (len(env.action_space),) + env.action_space[0].shape
-    else:
-        expected_shape = env.action_space.shape
-    assert action["action"].shape == expected_shape
+    for _ in range(64):
+        action, _next_state = ppo.predict(obs)
+        obs, reward, done, truncated, _ = env.step(action)
+        if done or truncated:
+            obs = env.reset()

 def test_ppo_training():
-    ppo = PPO(simple_env, total_timesteps=100)
+    ppo = PPO(simple_env, total_timesteps=10000)
     env = ppo.make_env()
     initial_performance = evaluate_policy(ppo, env)
     ppo.train()
     final_performance = evaluate_policy(ppo, env)
+    assert final_performance > initial_performance, "PPO should improve performance after training"

-def evaluate_policy(policy, env, n_eval_episodes=3):
+def evaluate_policy(policy, env, n_eval_episodes=10):
     total_reward = 0
     for _ in range(n_eval_episodes):
-        tensordict = env.reset()
+        obs = env.reset()
         done = False
         while not done:
-            action = policy.predict(tensordict)
-            next_tensordict = env.step(action).get("next")
-            total_reward += next_tensordict["reward"]
-            done = next_tensordict["done"]
-            tensordict = next_tensordict
+            action, _next_state = policy.predict(obs)
+            obs, reward, terminated, truncated, _ = env.step(action)
+            total_reward += reward
+            done = terminated or truncated
     return total_reward / n_eval_episodes


@@ -2,10 +2,9 @@ import pytest
 import numpy as np
 from fancy_rl import TRPL
 import gymnasium as gym
-from tensordict import TensorDict

 def simple_env():
-    return gym.make('Pendulum-v1')
+    return gym.make('LunarLander-v2', continuous=True)

 def test_trpl_instantiation():
     trpl = TRPL(simple_env)
@@ -35,41 +34,52 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
     assert trpl.n_steps == n_steps
     assert trpl.batch_size == batch_size
     assert trpl.gamma == gamma
-    assert trpl.projection.mean_bound == trust_region_bound_mean
-    assert trpl.projection.cov_bound == trust_region_bound_cov
+    assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
+    assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov

 def test_trpl_predict():
     trpl = TRPL(simple_env)
     env = trpl.make_env()
-    obs = env.reset()
-    action = trpl.predict(obs)
-    assert isinstance(action, TensorDict)
-    # Handle both single and composite action spaces
-    if isinstance(env.action_space, list):
-        expected_shape = (len(env.action_space),) + env.action_space[0].shape
-    else:
-        expected_shape = env.action_space.shape
-    assert action["action"].shape == expected_shape
+    obs, _ = env.reset()
+    action, _ = trpl.predict(obs)
+    assert isinstance(action, np.ndarray)
+    assert action.shape == env.action_space.shape
+
+def test_trpl_learn():
+    trpl = TRPL(simple_env, n_steps=64, batch_size=32)
+    env = trpl.make_env()
+    obs, _ = env.reset()
+    for _ in range(64):
+        action, _ = trpl.predict(obs)
+        next_obs, reward, done, truncated, _ = env.step(action)
+        trpl.store_transition(obs, action, reward, done, next_obs)
+        obs = next_obs
+        if done or truncated:
+            obs, _ = env.reset()
+    loss = trpl.learn()
+    assert isinstance(loss, dict)
+    assert "policy_loss" in loss
+    assert "value_loss" in loss

 def test_trpl_training():
-    trpl = TRPL(simple_env, total_timesteps=100)
+    trpl = TRPL(simple_env, total_timesteps=10000)
     env = trpl.make_env()
     initial_performance = evaluate_policy(trpl, env)
     trpl.train()
     final_performance = evaluate_policy(trpl, env)
+    assert final_performance > initial_performance, "TRPL should improve performance after training"

-def evaluate_policy(policy, env, n_eval_episodes=3):
+def evaluate_policy(policy, env, n_eval_episodes=10):
     total_reward = 0
     for _ in range(n_eval_episodes):
-        tensordict = env.reset()
+        obs, _ = env.reset()
         done = False
         while not done:
-            action = policy.predict(tensordict)
-            next_tensordict = env.step(action).get("next")
-            total_reward += next_tensordict["reward"]
-            done = next_tensordict["done"]
-            tensordict = next_tensordict
+            action, _ = policy.predict(obs)
+            obs, reward, terminated, truncated, _ = env.step(action)
+            total_reward += reward
+            done = terminated or truncated
     return total_reward / n_eval_episodes