Compare commits

5 Commits: e938018494 ... 3816adef9a

SHA1:
3816adef9a
04be117a95
fe7c7b3db0
90666a695c
c1189351cf
.gitignore (vendored, 1 changed line)
@@ -1,5 +1,6 @@
__pycache__
.venv
.vscode
wandb
*.egg-info/
test.py

@@ -1,9 +1,10 @@
import torch
import gymnasium as gym
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, SerialEnv
from torchrl.record import VideoRecorder
from abc import ABC
import pdb
import numpy as np
from tensordict import TensorDict
from torchrl.envs import GymWrapper, TransformedEnv
from torchrl.envs import BatchSizeTransform

@@ -56,27 +57,31 @@ class Algo(ABC):
        return env

    def _wrap_env(self, env_spec):
        # If given an existing env, ensure it's properly batched
        if isinstance(env_spec, (GymEnv, GymWrapper)):
            if not env_spec.batch_size:
                raise ValueError("Environment must be batched")
            return env_spec

        # Handle callable without wrapping the recursive call
        if callable(env_spec):
            return self._wrap_env(env_spec())

        # Create new batched environment using SerialEnv
        if isinstance(env_spec, str):
            env = GymEnv(env_spec, device=self.device)
            env = SerialEnv(1, lambda: GymEnv(env_spec, device=self.device))
        elif isinstance(env_spec, gym.Env):
            env = GymWrapper(env_spec, device=self.device)
        elif isinstance(env_spec, GymEnv):
            env = env_spec
        elif callable(env_spec):
            base_env = env_spec()
            return self._wrap_env(base_env)
            wrapped_env = GymWrapper(env_spec, device=self.device)
            if wrapped_env.batch_size:
                env = wrapped_env
            else:
                env = SerialEnv(1, lambda: wrapped_env)
        else:
            raise ValueError(
                f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
                f"got {type(env_spec)}"
            )

        if not env.batch_size:
            env = TransformedEnv(
                env,
                BatchSizeTransform(batch_size=torch.Size([1]))
            )

        return env

    def train_step(self, batch):
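
For reference, a minimal sketch (not part of this diff) of the two batching strategies the new _wrap_env relies on: building the env inside SerialEnv so it is batched from the start, and falling back to BatchSizeTransform for an env that is not. The environment name is illustrative; whether the transform fallback runs end to end depends on the wrapped env accepting a leading batch dimension.

```python
import torch
from torchrl.envs import GymEnv, SerialEnv, TransformedEnv, BatchSizeTransform

# Strategy 1: construct the env inside SerialEnv so it reports batch_size = [1]
batched = SerialEnv(1, lambda: GymEnv("Pendulum-v1"))
print(batched.batch_size)  # torch.Size([1])

# Strategy 2 (mirrors the fallback in _wrap_env above): add a leading batch
# dimension to an already-built env via BatchSizeTransform
base = GymEnv("Pendulum-v1")
batched_via_transform = TransformedEnv(base, BatchSizeTransform(batch_size=torch.Size([1])))
```
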
@@ -90,18 +95,21 @@ class Algo(ABC):

    def predict(
        self,
        observation,
        tensordict,
        state=None,
        deterministic=False
    ):
        with torch.no_grad():
            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
            td = TensorDict({"observation": obs_tensor})
            # If numpy array, convert to TensorDict
            if isinstance(tensordict, np.ndarray):
                tensordict = TensorDict(
                    {"observation": torch.from_numpy(tensordict).float()},
                    batch_size=[]
                )

            action_td = self.prob_actor(td)
            action = action_td["action"]
            # Move to device
            tensordict = tensordict.to(self.device)

            # We're not using recurrent policies, so we'll always return None for the state
            next_state = None

            return action.squeeze(0).cpu().numpy(), next_state
            # Get action from policy
            action_td = self.prob_actor(tensordict)
            return action_td
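
As a usage sketch, the reworked predict is driven by TensorDicts rather than raw numpy observations; the pattern below mirrors test_ppo_predict further down in this diff (PPO, make_env, and the "action" key are taken from there, the rest is illustrative).

```python
from tensordict import TensorDict
from fancy_rl import PPO

ppo = PPO("CartPole-v1")      # instantiation from an env id, as in the tests
env = ppo.make_env()

td = env.reset()              # TorchRL envs reset to a TensorDict
action_td = ppo.predict(td)   # the reworked predict returns a TensorDict
assert isinstance(action_td, TensorDict)
print(action_td["action"].shape)
```
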
@@ -43,13 +43,13 @@ class PPO(OnPolicy):
        env = self.make_env()

        # Get spaces from specs for parallel env
        obs_space = env.observation_spec
        act_space = env.action_spec
        self.obs_space = env.observation_spec
        self.act_space = env.action_spec

        self.discrete = isinstance(act_space, DiscreteTensorSpec)
        self.discrete = isinstance(self.act_space, DiscreteTensorSpec)

        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
        self.actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

        if self.discrete:
            distribution_class = torch.distributions.Categorical

@@ -14,6 +14,8 @@ from fancy_rl.objectives import TRPLLoss
from copy import deepcopy
from tensordict.nn import TensorDictModule
from tensordict import TensorDict
from torch.distributions import Categorical, MultivariateNormal, Normal


class ProjectedActor(TensorDictModule):
    def __init__(self, raw_actor, old_actor, projection):
@@ -26,6 +28,8 @@ class ProjectedActor(TensorDictModule):
        self.raw_actor = raw_actor
        self.old_actor = old_actor
        self.projection = projection
        self.discrete = raw_actor.discrete
        self.full_covariance = raw_actor.full_covariance

    class CombinedModule(nn.Module):
        def __init__(self, raw_actor, old_actor, projection):
@@ -35,12 +39,41 @@ class ProjectedActor(TensorDictModule):
            self.projection = projection

        def forward(self, tensordict):
            # Convert the tuple outputs to TensorDict
            raw_params = self.raw_actor(tensordict)
            if isinstance(raw_params, tuple):
                raw_params = TensorDict({
                    "loc": raw_params[0],
                    "scale": raw_params[1]
                }, batch_size=[raw_params[0].shape[0]])  # Use the first dimension of the tensor as batch size

            old_params = self.old_actor(tensordict)
            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
            if isinstance(old_params, tuple):
                old_params = TensorDict({
                    "loc": old_params[0],
                    "scale": old_params[1]
                }, batch_size=[old_params[0].shape[0]])  # Use the first dimension of the tensor as batch size

            # Now combine them
            combined_params = TensorDict({
                **raw_params,
                **{f"old_{key}": value for key, value in old_params.items()}
            }, batch_size=[raw_params["loc"].shape[0]])  # Use the first dimension of loc tensor as batch size

            projected_params = self.projection(combined_params)
            return projected_params

    def get_dist(self, tensordict):
        # Forward the observation through the network
        out = self.forward(tensordict)
        if self.discrete:
            return Categorical(logits=out["logits"])
        else:
            if self.full_covariance:
                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
            else:
                return Normal(loc=out["loc"], scale=out["scale"])

class TRPL(OnPolicy):
    def __init__(
        self,
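
To illustrate the key-prefixing scheme used in CombinedModule.forward above (current and old distribution parameters merged into one TensorDict, with the old ones renamed to old_*), a self-contained toy example with made-up shapes:

```python
import torch
from tensordict import TensorDict

# Toy "current" and "old" Gaussian parameters with a batch of 4
raw_params = TensorDict({"loc": torch.zeros(4, 2), "scale": torch.ones(4, 2)}, batch_size=[4])
old_params = TensorDict({"loc": torch.zeros(4, 2), "scale": torch.ones(4, 2)}, batch_size=[4])

# Same merge as in the diff: old entries get an "old_" prefix
combined = TensorDict(
    {**raw_params, **{f"old_{key}": value for key, value in old_params.items()}},
    batch_size=[raw_params["loc"].shape[0]],
)
print(sorted(combined.keys()))  # ['loc', 'old_loc', 'old_scale', 'scale']
```
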
@@ -77,14 +110,16 @@ class TRPL(OnPolicy):
        # Initialize environment to get observation and action space sizes
        self.env_spec = env_spec
        env = self.make_env()
        obs_space = env.observation_space
        act_space = env.action_space

        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
        # Get spaces from specs for parallel env
        self.obs_space = env.observation_spec
        self.act_space = env.action_spec

        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
        assert not isinstance(self.act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"

        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
        self.raw_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
        self.old_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

        # Handle projection_class
        if isinstance(projection_class, str):

@@ -4,6 +4,7 @@ from torchrl.modules import MLP
from torchrl.data.tensor_specs import DiscreteTensorSpec
from tensordict.nn.distributions import NormalParamExtractor
from tensordict import TensorDict
from torch.distributions import Categorical, MultivariateNormal, Normal

class Actor(TensorDictModule):
    def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
@@ -50,6 +51,17 @@ class Actor(TensorDictModule):
            out_keys=out_keys
        )

    def get_dist(self, tensordict):
        # Forward the observation through the network
        out = self.forward(tensordict)
        if self.discrete:
            return Categorical(logits=out["logits"])
        else:
            if self.full_covariance:
                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
            else:
                return Normal(loc=out["loc"], scale=out["scale"])

class FullCovarianceNormalParamExtractor(nn.Module):
    def __init__(self, action_dim):
        super().__init__()
@@ -65,8 +77,11 @@ class FullCovarianceNormalParamExtractor(nn.Module):

class Critic(TensorDictModule):
    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
        obs_space = obs_space["observation"]
        obs_space_shape = obs_space.shape[1:]

        critic_module = MLP(
            in_features=obs_space.shape[-1],
            in_features=obs_space_shape[0],
            out_features=1,
            num_cells=hidden_sizes,
            activation_class=getattr(nn, activation_fn),
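
The Critic change above reads the input width from a batched observation spec rather than from the last dimension. A small sketch of what that spec looks like for a SerialEnv-wrapped env (env name illustrative; the exact shape depends on the env):

```python
from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(1, lambda: GymEnv("Pendulum-v1"))
obs_spec = env.observation_spec["observation"]
print(obs_spec.shape)                  # leading batch dim, e.g. torch.Size([1, 3])

obs_space_shape = obs_spec.shape[1:]   # strip the batch dim, as the Critic now does
in_features = obs_space_shape[0]       # 3 for Pendulum-v1
```
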
@@ -1,5 +1,9 @@
import torch
import cpp_projection
try:
    import cpp_projection
    cpp_projection_available = True
except ImportError:
    cpp_projection_available = False
import numpy as np
from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule
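
The try/except import above makes the compiled cpp_projection backend optional. A common companion pattern is to fail with a clear message only when the backend is actually needed; the helper below is hypothetical and not part of this diff:

```python
try:
    import cpp_projection
    cpp_projection_available = True
except ImportError:
    cpp_projection_available = False


def require_cpp_projection():
    # Hypothetical helper: call this where the compiled backend is actually used.
    if not cpp_projection_available:
        raise ImportError(
            "cpp_projection is not installed; install it to use the C++ projection backend."
        )
    return cpp_projection
```
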
@@ -8,7 +8,7 @@
description = "Minimalistic and efficient implementations of PPO and TRPL for torchrl"
authors = [{name = "Dominik Roth", email = "mail@dominik-roth.eu"}]
readme = "README.md"
requires-python = ">=3.7,<3.12"
requires-python = ">=3.7"
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Developers",
@@ -23,9 +23,10 @@
dependencies = [
    "numpy",
    "torch",
    "gymnasium",
    "gymnasium<1.0",
    "tensordict",
    "torchrl",
    "pytest",
]

[project.urls]
@@ -33,3 +34,4 @@

[project.optional-dependencies]
dev = ["pytest"]
box2d = ["swig", "gymnasium[box2d]"]

@@ -3,9 +3,11 @@ import numpy as np
from fancy_rl import PPO
import gymnasium as gym
from torchrl.envs import GymEnv
import torch as th
from tensordict import TensorDict

def simple_env():
    return GymEnv('LunarLander-v2', continuous=True)
    return gym.make('LunarLander-v2')

def test_ppo_instantiation():
    ppo = PPO(simple_env)
@@ -15,69 +17,38 @@ def test_ppo_instantiation_from_str():
    ppo = PPO('CartPole-v1')
    assert isinstance(ppo, PPO)

def test_ppo_instantiation_from_make():
    ppo = PPO(gym.make('CartPole-v1'))
    assert isinstance(ppo, PPO)

@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("n_epochs", [5, 10])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
    ppo = PPO(
        simple_env,
        learning_rate=learning_rate,
        n_steps=n_steps,
        batch_size=batch_size,
        n_epochs=n_epochs,
        gamma=gamma,
        clip_range=clip_range
    )
    assert ppo.learning_rate == learning_rate
    assert ppo.n_steps == n_steps
    assert ppo.batch_size == batch_size
    assert ppo.n_epochs == n_epochs
    assert ppo.gamma == gamma
    assert ppo.clip_range == clip_range

def test_ppo_predict():
    ppo = PPO(simple_env)
    env = ppo.make_env()
    obs, _ = env.reset()
    action, _ = ppo.predict(obs)
    assert isinstance(action, np.ndarray)
    assert action.shape == env.action_space.shape

def test_ppo_learn():
    ppo = PPO(simple_env, n_steps=64, batch_size=32)
    env = ppo.make_env()
    obs = env.reset()
    for _ in range(64):
        action, _next_state = ppo.predict(obs)
        obs, reward, done, truncated, _ = env.step(action)
        if done or truncated:
            obs = env.reset()
    action = ppo.predict(obs)
    assert isinstance(action, TensorDict)

    # Handle both single and composite action spaces
    if isinstance(env.action_space, list):
        expected_shape = (len(env.action_space),) + env.action_space[0].shape
    else:
        expected_shape = env.action_space.shape

    assert action["action"].shape == expected_shape

def test_ppo_training():
    ppo = PPO(simple_env, total_timesteps=10000)
    ppo = PPO(simple_env, total_timesteps=100)
    env = ppo.make_env()

    initial_performance = evaluate_policy(ppo, env)
    ppo.train()
    final_performance = evaluate_policy(ppo, env)

    assert final_performance > initial_performance, "PPO should improve performance after training"

def evaluate_policy(policy, env, n_eval_episodes=10):
def evaluate_policy(policy, env, n_eval_episodes=3):
    total_reward = 0
    for _ in range(n_eval_episodes):
        obs = env.reset()
        tensordict = env.reset()
        done = False
        while not done:
            action, _next_state = policy.predict(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
            action = policy.predict(tensordict)
            next_tensordict = env.step(action).get("next")
            total_reward += next_tensordict["reward"]
            done = next_tensordict["done"]
            tensordict = next_tensordict
    return total_reward / n_eval_episodes
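
The rewritten evaluate_policy above steps the environment through TensorDicts: env.step returns a tensordict whose "next" entry carries the new observation, reward, and done flag. A compact sketch of that loop with a random action standing in for policy.predict (env name and action choice are illustrative):

```python
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
td = env.reset()                           # TensorDict with "observation", "done", ...
for _ in range(5):
    td["action"] = env.action_spec.rand()  # stand-in for policy.predict(td)
    next_td = env.step(td).get("next")     # "next" holds observation, reward, done
    td = next_td
    print(float(next_td["reward"]), bool(next_td["done"]))
```
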
@@ -2,9 +2,10 @@ import pytest
import numpy as np
from fancy_rl import TRPL
import gymnasium as gym
from tensordict import TensorDict

def simple_env():
    return gym.make('LunarLander-v2', continuous=True)
    return gym.make('Pendulum-v1')

def test_trpl_instantiation():
    trpl = TRPL(simple_env)
@@ -34,52 +35,41 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
    assert trpl.n_steps == n_steps
    assert trpl.batch_size == batch_size
    assert trpl.gamma == gamma
    assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
    assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
    assert trpl.projection.mean_bound == trust_region_bound_mean
    assert trpl.projection.cov_bound == trust_region_bound_cov

def test_trpl_predict():
    trpl = TRPL(simple_env)
    env = trpl.make_env()
    obs, _ = env.reset()
    action, _ = trpl.predict(obs)
    assert isinstance(action, np.ndarray)
    assert action.shape == env.action_space.shape
    obs = env.reset()
    action = trpl.predict(obs)
    assert isinstance(action, TensorDict)

def test_trpl_learn():
    trpl = TRPL(simple_env, n_steps=64, batch_size=32)
    env = trpl.make_env()
    obs, _ = env.reset()
    for _ in range(64):
        action, _ = trpl.predict(obs)
        next_obs, reward, done, truncated, _ = env.step(action)
        trpl.store_transition(obs, action, reward, done, next_obs)
        obs = next_obs
        if done or truncated:
            obs, _ = env.reset()
    # Handle both single and composite action spaces
    if isinstance(env.action_space, list):
        expected_shape = (len(env.action_space),) + env.action_space[0].shape
    else:
        expected_shape = env.action_space.shape

    loss = trpl.learn()
    assert isinstance(loss, dict)
    assert "policy_loss" in loss
    assert "value_loss" in loss
    assert action["action"].shape == expected_shape

def test_trpl_training():
    trpl = TRPL(simple_env, total_timesteps=10000)
    trpl = TRPL(simple_env, total_timesteps=100)
    env = trpl.make_env()

    initial_performance = evaluate_policy(trpl, env)
    trpl.train()
    final_performance = evaluate_policy(trpl, env)

    assert final_performance > initial_performance, "TRPL should improve performance after training"

def evaluate_policy(policy, env, n_eval_episodes=10):
def evaluate_policy(policy, env, n_eval_episodes=3):
    total_reward = 0
    for _ in range(n_eval_episodes):
        obs, _ = env.reset()
        tensordict = env.reset()
        done = False
        while not done:
            action, _ = policy.predict(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
            action = policy.predict(tensordict)
            next_tensordict = env.step(action).get("next")
            total_reward += next_tensordict["reward"]
            done = next_tensordict["done"]
            tensordict = next_tensordict
    return total_reward / n_eval_episodes