Compare commits

..

5 Commits

9 changed files with 147 additions and 121 deletions

1
.gitignore vendored
View File

@ -1,5 +1,6 @@
__pycache__
.venv
.vscode
wandb
*.egg-info/
test.py

View File

@ -1,9 +1,10 @@
import torch
import gymnasium as gym
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, SerialEnv
from torchrl.record import VideoRecorder
from abc import ABC
import pdb
import numpy as np
from tensordict import TensorDict
from torchrl.envs import GymWrapper, TransformedEnv
from torchrl.envs import BatchSizeTransform
@ -56,27 +57,31 @@ class Algo(ABC):
return env
def _wrap_env(self, env_spec):
# If given an existing env, ensure it's properly batched
if isinstance(env_spec, (GymEnv, GymWrapper)):
if not env_spec.batch_size:
raise ValueError("Environment must be batched")
return env_spec
# Handle callable without wrapping the recursive call
if callable(env_spec):
return self._wrap_env(env_spec())
# Create new batched environment using SerialEnv
if isinstance(env_spec, str):
env = GymEnv(env_spec, device=self.device)
env = SerialEnv(1, lambda: GymEnv(env_spec, device=self.device))
elif isinstance(env_spec, gym.Env):
env = GymWrapper(env_spec, device=self.device)
elif isinstance(env_spec, GymEnv):
env = env_spec
elif callable(env_spec):
base_env = env_spec()
return self._wrap_env(base_env)
wrapped_env = GymWrapper(env_spec, device=self.device)
if wrapped_env.batch_size:
env = wrapped_env
else:
env = SerialEnv(1, lambda: wrapped_env)
else:
raise ValueError(
f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
f"got {type(env_spec)}"
)
if not env.batch_size:
env = TransformedEnv(
env,
BatchSizeTransform(batch_size=torch.Size([1]))
)
return env
def train_step(self, batch):
@ -90,18 +95,21 @@ class Algo(ABC):
def predict(
self,
observation,
tensordict,
state=None,
deterministic=False
):
with torch.no_grad():
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
td = TensorDict({"observation": obs_tensor})
# If numpy array, convert to TensorDict
if isinstance(tensordict, np.ndarray):
tensordict = TensorDict(
{"observation": torch.from_numpy(tensordict).float()},
batch_size=[]
)
action_td = self.prob_actor(td)
action = action_td["action"]
# Move to device
tensordict = tensordict.to(self.device)
# We're not using recurrent policies, so we'll always return None for the state
next_state = None
return action.squeeze(0).cpu().numpy(), next_state
# Get action from policy
action_td = self.prob_actor(tensordict)
return action_td

View File

@ -43,13 +43,13 @@ class PPO(OnPolicy):
env = self.make_env()
# Get spaces from specs for parallel env
obs_space = env.observation_spec
act_space = env.action_spec
self.obs_space = env.observation_spec
self.act_space = env.action_spec
self.discrete = isinstance(act_space, DiscreteTensorSpec)
self.discrete = isinstance(self.act_space, DiscreteTensorSpec)
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
if self.discrete:
distribution_class = torch.distributions.Categorical

View File

@ -14,6 +14,8 @@ from fancy_rl.objectives import TRPLLoss
from copy import deepcopy
from tensordict.nn import TensorDictModule
from tensordict import TensorDict
from torch.distributions import Categorical, MultivariateNormal, Normal
class ProjectedActor(TensorDictModule):
def __init__(self, raw_actor, old_actor, projection):
@ -26,6 +28,8 @@ class ProjectedActor(TensorDictModule):
self.raw_actor = raw_actor
self.old_actor = old_actor
self.projection = projection
self.discrete = raw_actor.discrete
self.full_covariance = raw_actor.full_covariance
class CombinedModule(nn.Module):
def __init__(self, raw_actor, old_actor, projection):
@ -35,12 +39,41 @@ class ProjectedActor(TensorDictModule):
self.projection = projection
def forward(self, tensordict):
# Convert the tuple outputs to TensorDict
raw_params = self.raw_actor(tensordict)
if isinstance(raw_params, tuple):
raw_params = TensorDict({
"loc": raw_params[0],
"scale": raw_params[1]
}, batch_size=[raw_params[0].shape[0]]) # Use the first dimension of the tensor as batch size
old_params = self.old_actor(tensordict)
combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
if isinstance(old_params, tuple):
old_params = TensorDict({
"loc": old_params[0],
"scale": old_params[1]
}, batch_size=[old_params[0].shape[0]]) # Use the first dimension of the tensor as batch size
# Now combine them
combined_params = TensorDict({
**raw_params,
**{f"old_{key}": value for key, value in old_params.items()}
}, batch_size=[raw_params["loc"].shape[0]]) # Use the first dimension of loc tensor as batch size
projected_params = self.projection(combined_params)
return projected_params
def get_dist(self, tensordict):
# Forward the observation through the network
out = self.forward(tensordict)
if self.discrete:
return Categorical(logits=out["logits"])
else:
if self.full_covariance:
return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
else:
return Normal(loc=out["loc"], scale=out["scale"])
class TRPL(OnPolicy):
def __init__(
self,
@ -77,14 +110,16 @@ class TRPL(OnPolicy):
# Initialize environment to get observation and action space sizes
self.env_spec = env_spec
env = self.make_env()
obs_space = env.observation_space
act_space = env.action_space
assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
# Get spaces from specs for parallel env
self.obs_space = env.observation_spec
self.act_space = env.action_spec
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
assert not isinstance(self.act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.raw_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
self.old_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
# Handle projection_class
if isinstance(projection_class, str):

View File

@ -4,6 +4,7 @@ from torchrl.modules import MLP
from torchrl.data.tensor_specs import DiscreteTensorSpec
from tensordict.nn.distributions import NormalParamExtractor
from tensordict import TensorDict
from torch.distributions import Categorical, MultivariateNormal, Normal
class Actor(TensorDictModule):
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
@ -50,6 +51,17 @@ class Actor(TensorDictModule):
out_keys=out_keys
)
def get_dist(self, tensordict):
# Forward the observation through the network
out = self.forward(tensordict)
if self.discrete:
return Categorical(logits=out["logits"])
else:
if self.full_covariance:
return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
else:
return Normal(loc=out["loc"], scale=out["scale"])
class FullCovarianceNormalParamExtractor(nn.Module):
def __init__(self, action_dim):
super().__init__()
@ -65,8 +77,11 @@ class FullCovarianceNormalParamExtractor(nn.Module):
class Critic(TensorDictModule):
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
obs_space = obs_space["observation"]
obs_space_shape = obs_space.shape[1:]
critic_module = MLP(
in_features=obs_space.shape[-1],
in_features=obs_space_shape[0],
out_features=1,
num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn),

View File

@ -1,5 +1,9 @@
import torch
import cpp_projection
try:
import cpp_projection
cpp_projection_available = True
except ImportError:
cpp_projection_available = False
import numpy as np
from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule

View File

@ -8,7 +8,7 @@
description = "Minimalistic and efficient implementations of PPO and TRPL for torchrl"
authors = [{name = "Dominik Roth", email = "mail@dominik-roth.eu"}]
readme = "README.md"
requires-python = ">=3.7,<3.12"
requires-python = ">=3.7"
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
@ -23,9 +23,10 @@
dependencies = [
"numpy",
"torch",
"gymnasium",
"gymnasium<1.0",
"tensordict",
"torchrl",
"pytest",
]
[project.urls]
@ -33,3 +34,4 @@
[project.optional-dependencies]
dev = ["pytest"]
box2d = ["swig", "gymnasium[box2d]"]

View File

@ -3,9 +3,11 @@ import numpy as np
from fancy_rl import PPO
import gymnasium as gym
from torchrl.envs import GymEnv
import torch as th
from tensordict import TensorDict
def simple_env():
return GymEnv('LunarLander-v2', continuous=True)
return gym.make('LunarLander-v2')
def test_ppo_instantiation():
ppo = PPO(simple_env)
@ -15,69 +17,38 @@ def test_ppo_instantiation_from_str():
ppo = PPO('CartPole-v1')
assert isinstance(ppo, PPO)
def test_ppo_instantiation_from_make():
ppo = PPO(gym.make('CartPole-v1'))
assert isinstance(ppo, PPO)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("n_epochs", [5, 10])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
ppo = PPO(
simple_env,
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,
n_epochs=n_epochs,
gamma=gamma,
clip_range=clip_range
)
assert ppo.learning_rate == learning_rate
assert ppo.n_steps == n_steps
assert ppo.batch_size == batch_size
assert ppo.n_epochs == n_epochs
assert ppo.gamma == gamma
assert ppo.clip_range == clip_range
def test_ppo_predict():
ppo = PPO(simple_env)
env = ppo.make_env()
obs, _ = env.reset()
action, _ = ppo.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == env.action_space.shape
obs = env.reset()
action = ppo.predict(obs)
assert isinstance(action, TensorDict)
def test_ppo_learn():
ppo = PPO(simple_env, n_steps=64, batch_size=32)
env = ppo.make_env()
obs = env.reset()
for _ in range(64):
action, _next_state = ppo.predict(obs)
obs, reward, done, truncated, _ = env.step(action)
if done or truncated:
obs = env.reset()
# Handle both single and composite action spaces
if isinstance(env.action_space, list):
expected_shape = (len(env.action_space),) + env.action_space[0].shape
else:
expected_shape = env.action_space.shape
assert action["action"].shape == expected_shape
def test_ppo_training():
ppo = PPO(simple_env, total_timesteps=10000)
ppo = PPO(simple_env, total_timesteps=100)
env = ppo.make_env()
initial_performance = evaluate_policy(ppo, env)
ppo.train()
final_performance = evaluate_policy(ppo, env)
assert final_performance > initial_performance, "PPO should improve performance after training"
def evaluate_policy(policy, env, n_eval_episodes=10):
def evaluate_policy(policy, env, n_eval_episodes=3):
total_reward = 0
for _ in range(n_eval_episodes):
obs = env.reset()
tensordict = env.reset()
done = False
while not done:
action, _next_state = policy.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
total_reward += reward
done = terminated or truncated
action = policy.predict(tensordict)
next_tensordict = env.step(action).get("next")
total_reward += next_tensordict["reward"]
done = next_tensordict["done"]
tensordict = next_tensordict
return total_reward / n_eval_episodes

View File

@ -2,9 +2,10 @@ import pytest
import numpy as np
from fancy_rl import TRPL
import gymnasium as gym
from tensordict import TensorDict
def simple_env():
return gym.make('LunarLander-v2', continuous=True)
return gym.make('Pendulum-v1')
def test_trpl_instantiation():
trpl = TRPL(simple_env)
@ -34,52 +35,41 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
assert trpl.n_steps == n_steps
assert trpl.batch_size == batch_size
assert trpl.gamma == gamma
assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
assert trpl.projection.mean_bound == trust_region_bound_mean
assert trpl.projection.cov_bound == trust_region_bound_cov
def test_trpl_predict():
trpl = TRPL(simple_env)
env = trpl.make_env()
obs, _ = env.reset()
action, _ = trpl.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == env.action_space.shape
obs = env.reset()
action = trpl.predict(obs)
assert isinstance(action, TensorDict)
def test_trpl_learn():
trpl = TRPL(simple_env, n_steps=64, batch_size=32)
env = trpl.make_env()
obs, _ = env.reset()
for _ in range(64):
action, _ = trpl.predict(obs)
next_obs, reward, done, truncated, _ = env.step(action)
trpl.store_transition(obs, action, reward, done, next_obs)
obs = next_obs
if done or truncated:
obs, _ = env.reset()
# Handle both single and composite action spaces
if isinstance(env.action_space, list):
expected_shape = (len(env.action_space),) + env.action_space[0].shape
else:
expected_shape = env.action_space.shape
loss = trpl.learn()
assert isinstance(loss, dict)
assert "policy_loss" in loss
assert "value_loss" in loss
assert action["action"].shape == expected_shape
def test_trpl_training():
trpl = TRPL(simple_env, total_timesteps=10000)
trpl = TRPL(simple_env, total_timesteps=100)
env = trpl.make_env()
initial_performance = evaluate_policy(trpl, env)
trpl.train()
final_performance = evaluate_policy(trpl, env)
assert final_performance > initial_performance, "TRPL should improve performance after training"
def evaluate_policy(policy, env, n_eval_episodes=10):
def evaluate_policy(policy, env, n_eval_episodes=3):
total_reward = 0
for _ in range(n_eval_episodes):
obs, _ = env.reset()
tensordict = env.reset()
done = False
while not done:
action, _ = policy.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
total_reward += reward
done = terminated or truncated
action = policy.predict(tensordict)
next_tensordict = env.step(action).get("next")
total_reward += next_tensordict["reward"]
done = next_tensordict["done"]
tensordict = next_tensordict
return total_reward / n_eval_episodes