Compare commits

..

5 Commits

9 changed files with 147 additions and 121 deletions

1
.gitignore vendored
View File

@ -1,5 +1,6 @@
__pycache__ __pycache__
.venv .venv
.vscode
wandb wandb
*.egg-info/ *.egg-info/
test.py test.py

View File

@ -1,9 +1,10 @@
import torch import torch
import gymnasium as gym import gymnasium as gym
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, SerialEnv
from torchrl.record import VideoRecorder from torchrl.record import VideoRecorder
from abc import ABC from abc import ABC
import pdb import pdb
import numpy as np
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.envs import GymWrapper, TransformedEnv from torchrl.envs import GymWrapper, TransformedEnv
from torchrl.envs import BatchSizeTransform from torchrl.envs import BatchSizeTransform
@ -56,27 +57,31 @@ class Algo(ABC):
return env return env
def _wrap_env(self, env_spec): def _wrap_env(self, env_spec):
# If given an existing env, ensure it's properly batched
if isinstance(env_spec, (GymEnv, GymWrapper)):
if not env_spec.batch_size:
raise ValueError("Environment must be batched")
return env_spec
# Handle callable without wrapping the recursive call
if callable(env_spec):
return self._wrap_env(env_spec())
# Create new batched environment using SerialEnv
if isinstance(env_spec, str): if isinstance(env_spec, str):
env = GymEnv(env_spec, device=self.device) env = SerialEnv(1, lambda: GymEnv(env_spec, device=self.device))
elif isinstance(env_spec, gym.Env): elif isinstance(env_spec, gym.Env):
env = GymWrapper(env_spec, device=self.device) wrapped_env = GymWrapper(env_spec, device=self.device)
elif isinstance(env_spec, GymEnv): if wrapped_env.batch_size:
env = env_spec env = wrapped_env
elif callable(env_spec): else:
base_env = env_spec() env = SerialEnv(1, lambda: wrapped_env)
return self._wrap_env(base_env)
else: else:
raise ValueError( raise ValueError(
f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, " f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
f"got {type(env_spec)}" f"got {type(env_spec)}"
) )
if not env.batch_size:
env = TransformedEnv(
env,
BatchSizeTransform(batch_size=torch.Size([1]))
)
return env return env
def train_step(self, batch): def train_step(self, batch):
@ -90,18 +95,21 @@ class Algo(ABC):
def predict( def predict(
self, self,
observation, tensordict,
state=None, state=None,
deterministic=False deterministic=False
): ):
with torch.no_grad(): with torch.no_grad():
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0) # If numpy array, convert to TensorDict
td = TensorDict({"observation": obs_tensor}) if isinstance(tensordict, np.ndarray):
tensordict = TensorDict(
{"observation": torch.from_numpy(tensordict).float()},
batch_size=[]
)
action_td = self.prob_actor(td) # Move to device
action = action_td["action"] tensordict = tensordict.to(self.device)
# We're not using recurrent policies, so we'll always return None for the state # Get action from policy
next_state = None action_td = self.prob_actor(tensordict)
return action_td
return action.squeeze(0).cpu().numpy(), next_state

View File

@ -43,13 +43,13 @@ class PPO(OnPolicy):
env = self.make_env() env = self.make_env()
# Get spaces from specs for parallel env # Get spaces from specs for parallel env
obs_space = env.observation_spec self.obs_space = env.observation_spec
act_space = env.action_spec self.act_space = env.action_spec
self.discrete = isinstance(act_space, DiscreteTensorSpec) self.discrete = isinstance(self.act_space, DiscreteTensorSpec)
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device) self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance) self.actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
if self.discrete: if self.discrete:
distribution_class = torch.distributions.Categorical distribution_class = torch.distributions.Categorical

View File

@ -14,6 +14,8 @@ from fancy_rl.objectives import TRPLLoss
from copy import deepcopy from copy import deepcopy
from tensordict.nn import TensorDictModule from tensordict.nn import TensorDictModule
from tensordict import TensorDict from tensordict import TensorDict
from torch.distributions import Categorical, MultivariateNormal, Normal
class ProjectedActor(TensorDictModule): class ProjectedActor(TensorDictModule):
def __init__(self, raw_actor, old_actor, projection): def __init__(self, raw_actor, old_actor, projection):
@ -26,6 +28,8 @@ class ProjectedActor(TensorDictModule):
self.raw_actor = raw_actor self.raw_actor = raw_actor
self.old_actor = old_actor self.old_actor = old_actor
self.projection = projection self.projection = projection
self.discrete = raw_actor.discrete
self.full_covariance = raw_actor.full_covariance
class CombinedModule(nn.Module): class CombinedModule(nn.Module):
def __init__(self, raw_actor, old_actor, projection): def __init__(self, raw_actor, old_actor, projection):
@ -35,12 +39,41 @@ class ProjectedActor(TensorDictModule):
self.projection = projection self.projection = projection
def forward(self, tensordict): def forward(self, tensordict):
# Convert the tuple outputs to TensorDict
raw_params = self.raw_actor(tensordict) raw_params = self.raw_actor(tensordict)
if isinstance(raw_params, tuple):
raw_params = TensorDict({
"loc": raw_params[0],
"scale": raw_params[1]
}, batch_size=[raw_params[0].shape[0]]) # Use the first dimension of the tensor as batch size
old_params = self.old_actor(tensordict) old_params = self.old_actor(tensordict)
combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size) if isinstance(old_params, tuple):
old_params = TensorDict({
"loc": old_params[0],
"scale": old_params[1]
}, batch_size=[old_params[0].shape[0]]) # Use the first dimension of the tensor as batch size
# Now combine them
combined_params = TensorDict({
**raw_params,
**{f"old_{key}": value for key, value in old_params.items()}
}, batch_size=[raw_params["loc"].shape[0]]) # Use the first dimension of loc tensor as batch size
projected_params = self.projection(combined_params) projected_params = self.projection(combined_params)
return projected_params return projected_params
def get_dist(self, tensordict):
# Forward the observation through the network
out = self.forward(tensordict)
if self.discrete:
return Categorical(logits=out["logits"])
else:
if self.full_covariance:
return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
else:
return Normal(loc=out["loc"], scale=out["scale"])
class TRPL(OnPolicy): class TRPL(OnPolicy):
def __init__( def __init__(
self, self,
@ -77,14 +110,16 @@ class TRPL(OnPolicy):
# Initialize environment to get observation and action space sizes # Initialize environment to get observation and action space sizes
self.env_spec = env_spec self.env_spec = env_spec
env = self.make_env() env = self.make_env()
obs_space = env.observation_space
act_space = env.action_space
assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces" # Get spaces from specs for parallel env
self.obs_space = env.observation_spec
self.act_space = env.action_spec
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device) assert not isinstance(self.act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance) self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.raw_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
self.old_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
# Handle projection_class # Handle projection_class
if isinstance(projection_class, str): if isinstance(projection_class, str):

View File

@ -4,6 +4,7 @@ from torchrl.modules import MLP
from torchrl.data.tensor_specs import DiscreteTensorSpec from torchrl.data.tensor_specs import DiscreteTensorSpec
from tensordict.nn.distributions import NormalParamExtractor from tensordict.nn.distributions import NormalParamExtractor
from tensordict import TensorDict from tensordict import TensorDict
from torch.distributions import Categorical, MultivariateNormal, Normal
class Actor(TensorDictModule): class Actor(TensorDictModule):
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False): def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
@ -50,6 +51,17 @@ class Actor(TensorDictModule):
out_keys=out_keys out_keys=out_keys
) )
def get_dist(self, tensordict):
# Forward the observation through the network
out = self.forward(tensordict)
if self.discrete:
return Categorical(logits=out["logits"])
else:
if self.full_covariance:
return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
else:
return Normal(loc=out["loc"], scale=out["scale"])
class FullCovarianceNormalParamExtractor(nn.Module): class FullCovarianceNormalParamExtractor(nn.Module):
def __init__(self, action_dim): def __init__(self, action_dim):
super().__init__() super().__init__()
@ -65,8 +77,11 @@ class FullCovarianceNormalParamExtractor(nn.Module):
class Critic(TensorDictModule): class Critic(TensorDictModule):
def __init__(self, obs_space, hidden_sizes, activation_fn, device): def __init__(self, obs_space, hidden_sizes, activation_fn, device):
obs_space = obs_space["observation"]
obs_space_shape = obs_space.shape[1:]
critic_module = MLP( critic_module = MLP(
in_features=obs_space.shape[-1], in_features=obs_space_shape[0],
out_features=1, out_features=1,
num_cells=hidden_sizes, num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn), activation_class=getattr(nn, activation_fn),

View File

@ -1,5 +1,9 @@
import torch import torch
try:
import cpp_projection import cpp_projection
cpp_projection_available = True
except ImportError:
cpp_projection_available = False
import numpy as np import numpy as np
from .base_projection import BaseProjection from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule from tensordict.nn import TensorDictModule

View File

@ -8,7 +8,7 @@
description = "Minimalistic and efficient implementations of PPO and TRPL for torchrl" description = "Minimalistic and efficient implementations of PPO and TRPL for torchrl"
authors = [{name = "Dominik Roth", email = "mail@dominik-roth.eu"}] authors = [{name = "Dominik Roth", email = "mail@dominik-roth.eu"}]
readme = "README.md" readme = "README.md"
requires-python = ">=3.7,<3.12" requires-python = ">=3.7"
classifiers = [ classifiers = [
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Intended Audience :: Developers", "Intended Audience :: Developers",
@ -23,9 +23,10 @@
dependencies = [ dependencies = [
"numpy", "numpy",
"torch", "torch",
"gymnasium", "gymnasium<1.0",
"tensordict", "tensordict",
"torchrl", "torchrl",
"pytest",
] ]
[project.urls] [project.urls]
@ -33,3 +34,4 @@
[project.optional-dependencies] [project.optional-dependencies]
dev = ["pytest"] dev = ["pytest"]
box2d = ["swig", "gymnasium[box2d]"]

View File

@ -3,9 +3,11 @@ import numpy as np
from fancy_rl import PPO from fancy_rl import PPO
import gymnasium as gym import gymnasium as gym
from torchrl.envs import GymEnv from torchrl.envs import GymEnv
import torch as th
from tensordict import TensorDict
def simple_env(): def simple_env():
return GymEnv('LunarLander-v2', continuous=True) return gym.make('LunarLander-v2')
def test_ppo_instantiation(): def test_ppo_instantiation():
ppo = PPO(simple_env) ppo = PPO(simple_env)
@ -15,69 +17,38 @@ def test_ppo_instantiation_from_str():
ppo = PPO('CartPole-v1') ppo = PPO('CartPole-v1')
assert isinstance(ppo, PPO) assert isinstance(ppo, PPO)
def test_ppo_instantiation_from_make():
ppo = PPO(gym.make('CartPole-v1'))
assert isinstance(ppo, PPO)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("n_epochs", [5, 10])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
ppo = PPO(
simple_env,
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,
n_epochs=n_epochs,
gamma=gamma,
clip_range=clip_range
)
assert ppo.learning_rate == learning_rate
assert ppo.n_steps == n_steps
assert ppo.batch_size == batch_size
assert ppo.n_epochs == n_epochs
assert ppo.gamma == gamma
assert ppo.clip_range == clip_range
def test_ppo_predict(): def test_ppo_predict():
ppo = PPO(simple_env) ppo = PPO(simple_env)
env = ppo.make_env() env = ppo.make_env()
obs, _ = env.reset() obs = env.reset()
action, _ = ppo.predict(obs) action = ppo.predict(obs)
assert isinstance(action, np.ndarray) assert isinstance(action, TensorDict)
assert action.shape == env.action_space.shape
def test_ppo_learn(): # Handle both single and composite action spaces
ppo = PPO(simple_env, n_steps=64, batch_size=32) if isinstance(env.action_space, list):
env = ppo.make_env() expected_shape = (len(env.action_space),) + env.action_space[0].shape
obs = env.reset() else:
for _ in range(64): expected_shape = env.action_space.shape
action, _next_state = ppo.predict(obs)
obs, reward, done, truncated, _ = env.step(action) assert action["action"].shape == expected_shape
if done or truncated:
obs = env.reset()
def test_ppo_training(): def test_ppo_training():
ppo = PPO(simple_env, total_timesteps=10000) ppo = PPO(simple_env, total_timesteps=100)
env = ppo.make_env() env = ppo.make_env()
initial_performance = evaluate_policy(ppo, env) initial_performance = evaluate_policy(ppo, env)
ppo.train() ppo.train()
final_performance = evaluate_policy(ppo, env) final_performance = evaluate_policy(ppo, env)
assert final_performance > initial_performance, "PPO should improve performance after training" def evaluate_policy(policy, env, n_eval_episodes=3):
def evaluate_policy(policy, env, n_eval_episodes=10):
total_reward = 0 total_reward = 0
for _ in range(n_eval_episodes): for _ in range(n_eval_episodes):
obs = env.reset() tensordict = env.reset()
done = False done = False
while not done: while not done:
action, _next_state = policy.predict(obs) action = policy.predict(tensordict)
obs, reward, terminated, truncated, _ = env.step(action) next_tensordict = env.step(action).get("next")
total_reward += reward total_reward += next_tensordict["reward"]
done = terminated or truncated done = next_tensordict["done"]
tensordict = next_tensordict
return total_reward / n_eval_episodes return total_reward / n_eval_episodes

View File

@ -2,9 +2,10 @@ import pytest
import numpy as np import numpy as np
from fancy_rl import TRPL from fancy_rl import TRPL
import gymnasium as gym import gymnasium as gym
from tensordict import TensorDict
def simple_env(): def simple_env():
return gym.make('LunarLander-v2', continuous=True) return gym.make('Pendulum-v1')
def test_trpl_instantiation(): def test_trpl_instantiation():
trpl = TRPL(simple_env) trpl = TRPL(simple_env)
@ -34,52 +35,41 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
assert trpl.n_steps == n_steps assert trpl.n_steps == n_steps
assert trpl.batch_size == batch_size assert trpl.batch_size == batch_size
assert trpl.gamma == gamma assert trpl.gamma == gamma
assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean assert trpl.projection.mean_bound == trust_region_bound_mean
assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov assert trpl.projection.cov_bound == trust_region_bound_cov
def test_trpl_predict(): def test_trpl_predict():
trpl = TRPL(simple_env) trpl = TRPL(simple_env)
env = trpl.make_env() env = trpl.make_env()
obs, _ = env.reset() obs = env.reset()
action, _ = trpl.predict(obs) action = trpl.predict(obs)
assert isinstance(action, np.ndarray) assert isinstance(action, TensorDict)
assert action.shape == env.action_space.shape
def test_trpl_learn(): # Handle both single and composite action spaces
trpl = TRPL(simple_env, n_steps=64, batch_size=32) if isinstance(env.action_space, list):
env = trpl.make_env() expected_shape = (len(env.action_space),) + env.action_space[0].shape
obs, _ = env.reset() else:
for _ in range(64): expected_shape = env.action_space.shape
action, _ = trpl.predict(obs)
next_obs, reward, done, truncated, _ = env.step(action)
trpl.store_transition(obs, action, reward, done, next_obs)
obs = next_obs
if done or truncated:
obs, _ = env.reset()
loss = trpl.learn() assert action["action"].shape == expected_shape
assert isinstance(loss, dict)
assert "policy_loss" in loss
assert "value_loss" in loss
def test_trpl_training(): def test_trpl_training():
trpl = TRPL(simple_env, total_timesteps=10000) trpl = TRPL(simple_env, total_timesteps=100)
env = trpl.make_env() env = trpl.make_env()
initial_performance = evaluate_policy(trpl, env) initial_performance = evaluate_policy(trpl, env)
trpl.train() trpl.train()
final_performance = evaluate_policy(trpl, env) final_performance = evaluate_policy(trpl, env)
assert final_performance > initial_performance, "TRPL should improve performance after training" def evaluate_policy(policy, env, n_eval_episodes=3):
def evaluate_policy(policy, env, n_eval_episodes=10):
total_reward = 0 total_reward = 0
for _ in range(n_eval_episodes): for _ in range(n_eval_episodes):
obs, _ = env.reset() tensordict = env.reset()
done = False done = False
while not done: while not done:
action, _ = policy.predict(obs) action = policy.predict(tensordict)
obs, reward, terminated, truncated, _ = env.step(action) next_tensordict = env.step(action).get("next")
total_reward += reward total_reward += next_tensordict["reward"]
done = terminated or truncated done = next_tensordict["done"]
tensordict = next_tensordict
return total_reward / n_eval_episodes return total_reward / n_eval_episodes