Compare commits

Comparing e938018494...3816adef9a (5 commits):

3816adef9a
04be117a95
fe7c7b3db0
90666a695c
c1189351cf
.gitignore (vendored)

@@ -1,5 +1,6 @@
 __pycache__
 .venv
+.vscode
 wandb
 *.egg-info/
 test.py
@@ -1,9 +1,10 @@
 import torch
 import gymnasium as gym
-from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
+from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, SerialEnv
 from torchrl.record import VideoRecorder
 from abc import ABC
 import pdb
+import numpy as np
 from tensordict import TensorDict
 from torchrl.envs import GymWrapper, TransformedEnv
 from torchrl.envs import BatchSizeTransform
@@ -56,27 +57,31 @@ class Algo(ABC):
         return env

     def _wrap_env(self, env_spec):
+        # If given an existing env, ensure it's properly batched
+        if isinstance(env_spec, (GymEnv, GymWrapper)):
+            if not env_spec.batch_size:
+                raise ValueError("Environment must be batched")
+            return env_spec
+
+        # Handle callable without wrapping the recursive call
+        if callable(env_spec):
+            return self._wrap_env(env_spec())
+
+        # Create new batched environment using SerialEnv
         if isinstance(env_spec, str):
-            env = GymEnv(env_spec, device=self.device)
+            env = SerialEnv(1, lambda: GymEnv(env_spec, device=self.device))
         elif isinstance(env_spec, gym.Env):
-            env = GymWrapper(env_spec, device=self.device)
-        elif isinstance(env_spec, GymEnv):
-            env = env_spec
-        elif callable(env_spec):
-            base_env = env_spec()
-            return self._wrap_env(base_env)
+            wrapped_env = GymWrapper(env_spec, device=self.device)
+            if wrapped_env.batch_size:
+                env = wrapped_env
+            else:
+                env = SerialEnv(1, lambda: wrapped_env)
         else:
             raise ValueError(
                 f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
                 f"got {type(env_spec)}"
             )
-
-        if not env.batch_size:
-            env = TransformedEnv(
-                env,
-                BatchSizeTransform(batch_size=torch.Size([1]))
-            )

         return env

     def train_step(self, batch):
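For orientation, here is a minimal sketch (not part of the diff) of what the reworked `_wrap_env` produces for each kind of `env_spec`; `Pendulum-v1` and the CPU device are illustrative assumptions.

    # Sketch only: how the new _wrap_env batches each env_spec variant.
    import torch
    import gymnasium as gym
    from torchrl.envs import GymEnv, GymWrapper, SerialEnv

    # A string spec becomes a single-worker SerialEnv with batch_size [1].
    env_from_str = SerialEnv(1, lambda: GymEnv("Pendulum-v1", device="cpu"))
    assert env_from_str.batch_size == torch.Size([1])

    # A raw gymnasium.Env is wrapped first, then batched only if needed.
    wrapped = GymWrapper(gym.make("Pendulum-v1"), device="cpu")
    env_from_gym = wrapped if wrapped.batch_size else SerialEnv(1, lambda: wrapped)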
@@ -90,18 +95,21 @@ class Algo(ABC):

     def predict(
             self,
-            observation,
+            tensordict,
             state=None,
             deterministic=False
     ):
         with torch.no_grad():
-            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
-            td = TensorDict({"observation": obs_tensor})
-
-            action_td = self.prob_actor(td)
-            action = action_td["action"]
-
-            # We're not using recurrent policies, so we'll always return None for the state
-            next_state = None
-
-            return action.squeeze(0).cpu().numpy(), next_state
+            # If numpy array, convert to TensorDict
+            if isinstance(tensordict, np.ndarray):
+                tensordict = TensorDict(
+                    {"observation": torch.from_numpy(tensordict).float()},
+                    batch_size=[]
+                )
+
+            # Move to device
+            tensordict = tensordict.to(self.device)
+
+            # Get action from policy
+            action_td = self.prob_actor(tensordict)
+            return action_td
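A minimal usage sketch of the new `predict` signature, mirroring how the updated tests call it; the environment id is an illustrative assumption.

    # Sketch: predict now takes a TensorDict (or a raw numpy observation)
    # and returns the policy's output TensorDict, which contains "action".
    from fancy_rl import PPO

    ppo = PPO("Pendulum-v1")
    env = ppo.make_env()
    td = env.reset()              # torchrl envs reset to a TensorDict
    action_td = ppo.predict(td)
    print(action_td["action"].shape)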
@@ -43,13 +43,13 @@ class PPO(OnPolicy):
         env = self.make_env()

         # Get spaces from specs for parallel env
-        obs_space = env.observation_spec
-        act_space = env.action_spec
+        self.obs_space = env.observation_spec
+        self.act_space = env.action_spec

-        self.discrete = isinstance(act_space, DiscreteTensorSpec)
+        self.discrete = isinstance(self.act_space, DiscreteTensorSpec)

-        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

         if self.discrete:
             distribution_class = torch.distributions.Categorical
@@ -14,6 +14,8 @@ from fancy_rl.objectives import TRPLLoss
 from copy import deepcopy
 from tensordict.nn import TensorDictModule
 from tensordict import TensorDict
+from torch.distributions import Categorical, MultivariateNormal, Normal
+

 class ProjectedActor(TensorDictModule):
     def __init__(self, raw_actor, old_actor, projection):
@@ -26,6 +28,8 @@ class ProjectedActor(TensorDictModule):
         self.raw_actor = raw_actor
         self.old_actor = old_actor
         self.projection = projection
+        self.discrete = raw_actor.discrete
+        self.full_covariance = raw_actor.full_covariance

     class CombinedModule(nn.Module):
         def __init__(self, raw_actor, old_actor, projection):
@@ -35,12 +39,41 @@ class ProjectedActor(TensorDictModule):
             self.projection = projection

         def forward(self, tensordict):
+            # Convert the tuple outputs to TensorDict
             raw_params = self.raw_actor(tensordict)
+            if isinstance(raw_params, tuple):
+                raw_params = TensorDict({
+                    "loc": raw_params[0],
+                    "scale": raw_params[1]
+                }, batch_size=[raw_params[0].shape[0]])  # Use the first dimension of the tensor as batch size
+
             old_params = self.old_actor(tensordict)
-            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
+            if isinstance(old_params, tuple):
+                old_params = TensorDict({
+                    "loc": old_params[0],
+                    "scale": old_params[1]
+                }, batch_size=[old_params[0].shape[0]])  # Use the first dimension of the tensor as batch size
+
+            # Now combine them
+            combined_params = TensorDict({
+                **raw_params,
+                **{f"old_{key}": value for key, value in old_params.items()}
+            }, batch_size=[raw_params["loc"].shape[0]])  # Use the first dimension of loc tensor as batch size
+
             projected_params = self.projection(combined_params)
             return projected_params

+    def get_dist(self, tensordict):
+        # Forward the observation through the network
+        out = self.forward(tensordict)
+        if self.discrete:
+            return Categorical(logits=out["logits"])
+        else:
+            if self.full_covariance:
+                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
+            else:
+                return Normal(loc=out["loc"], scale=out["scale"])
+

 class TRPL(OnPolicy):
     def __init__(
         self,
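To make the key layout explicit, here is a small self-contained sketch (not from the repository) of the TensorDict that `forward` hands to the projection after merging new and old parameters; the sizes are illustrative assumptions.

    # Sketch: the combined TensorDict exposes the new ("loc", "scale") and the
    # prefixed old ("old_loc", "old_scale") parameters under one batch dimension.
    import torch
    from tensordict import TensorDict

    batch, act_dim = 4, 2  # illustrative sizes
    raw_params = TensorDict({"loc": torch.zeros(batch, act_dim),
                             "scale": torch.ones(batch, act_dim)}, batch_size=[batch])
    old_params = TensorDict({"loc": torch.zeros(batch, act_dim),
                             "scale": torch.ones(batch, act_dim)}, batch_size=[batch])

    combined = TensorDict({
        **raw_params,
        **{f"old_{k}": v for k, v in old_params.items()},
    }, batch_size=[batch])

    assert set(combined.keys()) == {"loc", "scale", "old_loc", "old_scale"}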
@@ -77,14 +110,16 @@ class TRPL(OnPolicy):
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
         env = self.make_env()
-        obs_space = env.observation_space
-        act_space = env.action_space

-        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
+        # Get spaces from specs for parallel env
+        self.obs_space = env.observation_spec
+        self.act_space = env.action_spec

-        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
-        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        assert not isinstance(self.act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
+
+        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.raw_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.old_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

         # Handle projection_class
         if isinstance(projection_class, str):
@@ -4,6 +4,7 @@ from torchrl.modules import MLP
 from torchrl.data.tensor_specs import DiscreteTensorSpec
 from tensordict.nn.distributions import NormalParamExtractor
 from tensordict import TensorDict
+from torch.distributions import Categorical, MultivariateNormal, Normal

 class Actor(TensorDictModule):
     def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
@@ -50,6 +51,17 @@ class Actor(TensorDictModule):
             out_keys=out_keys
         )

+    def get_dist(self, tensordict):
+        # Forward the observation through the network
+        out = self.forward(tensordict)
+        if self.discrete:
+            return Categorical(logits=out["logits"])
+        else:
+            if self.full_covariance:
+                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
+            else:
+                return Normal(loc=out["loc"], scale=out["scale"])
+

 class FullCovarianceNormalParamExtractor(nn.Module):
     def __init__(self, action_dim):
         super().__init__()
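As a usage sketch (assuming an `actor` built as above and a TensorDict `td` carrying a batched "observation"; neither is defined in the diff itself), `get_dist` plugs directly into the standard torch.distributions API:

    # Hypothetical usage: `actor` and `td` are assumed to exist as described above.
    dist = actor.get_dist(td)          # Categorical, Normal, or MultivariateNormal
    action = dist.sample()             # one action per batch entry
    log_prob = dist.log_prob(action)   # log-density, e.g. for the policy loss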
@@ -65,8 +77,11 @@ class FullCovarianceNormalParamExtractor(nn.Module):

 class Critic(TensorDictModule):
     def __init__(self, obs_space, hidden_sizes, activation_fn, device):
+        obs_space = obs_space["observation"]
+        obs_space_shape = obs_space.shape[1:]
+
         critic_module = MLP(
-            in_features=obs_space.shape[-1],
+            in_features=obs_space_shape[0],
             out_features=1,
             num_cells=hidden_sizes,
             activation_class=getattr(nn, activation_fn),
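The two new lines account for the batched environments introduced above: the composite observation spec now carries a leading batch dimension, so the feature size is what remains after stripping it. A rough sketch with an illustrative shape:

    # Sketch: with SerialEnv(1, ...) around e.g. Pendulum-v1 (obs_dim=3), the
    # "observation" spec has shape [1, 3]; dropping the batch dim gives in_features.
    import torch

    spec_shape = torch.Size([1, 3])   # illustrative, not read from the repo
    obs_space_shape = spec_shape[1:]  # torch.Size([3])
    in_features = obs_space_shape[0]  # 3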
@@ -1,5 +1,9 @@
 import torch
-import cpp_projection
+try:
+    import cpp_projection
+    cpp_projection_available = True
+except ImportError:
+    cpp_projection_available = False
 import numpy as np
 from .base_projection import BaseProjection
 from tensordict.nn import TensorDictModule
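A sketch of how the new flag might be consumed further down in the module; only the guarded import appears in the diff, so the check and its message are assumptions:

    # Hypothetical downstream check instead of failing at import time.
    if not cpp_projection_available:
        raise ImportError(
            "cpp_projection is required for this projection; install it "
            "or pick a projection that does not depend on it."
        )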
@@ -8,7 +8,7 @@
 description = "Minimalistic and efficient implementations of PPO and TRPL for torchrl"
 authors = [{name = "Dominik Roth", email = "mail@dominik-roth.eu"}]
 readme = "README.md"
-requires-python = ">=3.7,<3.12"
+requires-python = ">=3.7"
 classifiers = [
     "Development Status :: 3 - Alpha",
     "Intended Audience :: Developers",
@@ -23,9 +23,10 @@
 dependencies = [
     "numpy",
     "torch",
-    "gymnasium",
+    "gymnasium<1.0",
     "tensordict",
     "torchrl",
+    "pytest",
 ]

 [project.urls]
@@ -33,3 +34,4 @@

 [project.optional-dependencies]
 dev = ["pytest"]
+box2d = ["swig", "gymnasium[box2d]"]
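With the new optional-dependency group, the Box2D environments can presumably be pulled in via the extra, e.g. `pip install "fancy_rl[box2d]"` (assuming the distribution is named fancy_rl, as the imports in the tests suggest).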
@@ -3,9 +3,11 @@ import numpy as np
 from fancy_rl import PPO
 import gymnasium as gym
 from torchrl.envs import GymEnv
+import torch as th
+from tensordict import TensorDict

 def simple_env():
-    return GymEnv('LunarLander-v2', continuous=True)
+    return gym.make('LunarLander-v2')

 def test_ppo_instantiation():
     ppo = PPO(simple_env)
@@ -15,69 +17,38 @@ def test_ppo_instantiation_from_str():
     ppo = PPO('CartPole-v1')
     assert isinstance(ppo, PPO)

-def test_ppo_instantiation_from_make():
-    ppo = PPO(gym.make('CartPole-v1'))
-    assert isinstance(ppo, PPO)
-
-@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
-@pytest.mark.parametrize("n_steps", [1024, 2048])
-@pytest.mark.parametrize("batch_size", [32, 64, 128])
-@pytest.mark.parametrize("n_epochs", [5, 10])
-@pytest.mark.parametrize("gamma", [0.95, 0.99])
-@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
-def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
-    ppo = PPO(
-        simple_env,
-        learning_rate=learning_rate,
-        n_steps=n_steps,
-        batch_size=batch_size,
-        n_epochs=n_epochs,
-        gamma=gamma,
-        clip_range=clip_range
-    )
-    assert ppo.learning_rate == learning_rate
-    assert ppo.n_steps == n_steps
-    assert ppo.batch_size == batch_size
-    assert ppo.n_epochs == n_epochs
-    assert ppo.gamma == gamma
-    assert ppo.clip_range == clip_range
-
 def test_ppo_predict():
     ppo = PPO(simple_env)
     env = ppo.make_env()
-    obs, _ = env.reset()
-    action, _ = ppo.predict(obs)
-    assert isinstance(action, np.ndarray)
-    assert action.shape == env.action_space.shape
-
-def test_ppo_learn():
-    ppo = PPO(simple_env, n_steps=64, batch_size=32)
-    env = ppo.make_env()
     obs = env.reset()
-    for _ in range(64):
-        action, _next_state = ppo.predict(obs)
-        obs, reward, done, truncated, _ = env.step(action)
-        if done or truncated:
-            obs = env.reset()
+    action = ppo.predict(obs)
+    assert isinstance(action, TensorDict)
+
+    # Handle both single and composite action spaces
+    if isinstance(env.action_space, list):
+        expected_shape = (len(env.action_space),) + env.action_space[0].shape
+    else:
+        expected_shape = env.action_space.shape
+
+    assert action["action"].shape == expected_shape

 def test_ppo_training():
-    ppo = PPO(simple_env, total_timesteps=10000)
+    ppo = PPO(simple_env, total_timesteps=100)
     env = ppo.make_env()

     initial_performance = evaluate_policy(ppo, env)
     ppo.train()
     final_performance = evaluate_policy(ppo, env)

-    assert final_performance > initial_performance, "PPO should improve performance after training"
-
-def evaluate_policy(policy, env, n_eval_episodes=10):
+def evaluate_policy(policy, env, n_eval_episodes=3):
     total_reward = 0
     for _ in range(n_eval_episodes):
-        obs = env.reset()
+        tensordict = env.reset()
         done = False
         while not done:
-            action, _next_state = policy.predict(obs)
-            obs, reward, terminated, truncated, _ = env.step(action)
-            total_reward += reward
-            done = terminated or truncated
+            action = policy.predict(tensordict)
+            next_tensordict = env.step(action).get("next")
+            total_reward += next_tensordict["reward"]
+            done = next_tensordict["done"]
+            tensordict = next_tensordict
     return total_reward / n_eval_episodes
@@ -2,9 +2,10 @@ import pytest
 import numpy as np
 from fancy_rl import TRPL
 import gymnasium as gym
+from tensordict import TensorDict

 def simple_env():
-    return gym.make('LunarLander-v2', continuous=True)
+    return gym.make('Pendulum-v1')

 def test_trpl_instantiation():
     trpl = TRPL(simple_env)
@@ -34,52 +35,41 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
     assert trpl.n_steps == n_steps
     assert trpl.batch_size == batch_size
     assert trpl.gamma == gamma
-    assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
-    assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
+    assert trpl.projection.mean_bound == trust_region_bound_mean
+    assert trpl.projection.cov_bound == trust_region_bound_cov

 def test_trpl_predict():
     trpl = TRPL(simple_env)
     env = trpl.make_env()
-    obs, _ = env.reset()
-    action, _ = trpl.predict(obs)
-    assert isinstance(action, np.ndarray)
-    assert action.shape == env.action_space.shape
-
-def test_trpl_learn():
-    trpl = TRPL(simple_env, n_steps=64, batch_size=32)
-    env = trpl.make_env()
-    obs, _ = env.reset()
-    for _ in range(64):
-        action, _ = trpl.predict(obs)
-        next_obs, reward, done, truncated, _ = env.step(action)
-        trpl.store_transition(obs, action, reward, done, next_obs)
-        obs = next_obs
-        if done or truncated:
-            obs, _ = env.reset()
-
-    loss = trpl.learn()
-    assert isinstance(loss, dict)
-    assert "policy_loss" in loss
-    assert "value_loss" in loss
+    obs = env.reset()
+    action = trpl.predict(obs)
+    assert isinstance(action, TensorDict)
+
+    # Handle both single and composite action spaces
+    if isinstance(env.action_space, list):
+        expected_shape = (len(env.action_space),) + env.action_space[0].shape
+    else:
+        expected_shape = env.action_space.shape
+
+    assert action["action"].shape == expected_shape

 def test_trpl_training():
-    trpl = TRPL(simple_env, total_timesteps=10000)
+    trpl = TRPL(simple_env, total_timesteps=100)
     env = trpl.make_env()

     initial_performance = evaluate_policy(trpl, env)
     trpl.train()
     final_performance = evaluate_policy(trpl, env)

-    assert final_performance > initial_performance, "TRPL should improve performance after training"
-
-def evaluate_policy(policy, env, n_eval_episodes=10):
+def evaluate_policy(policy, env, n_eval_episodes=3):
     total_reward = 0
     for _ in range(n_eval_episodes):
-        obs, _ = env.reset()
+        tensordict = env.reset()
         done = False
         while not done:
-            action, _ = policy.predict(obs)
-            obs, reward, terminated, truncated, _ = env.step(action)
-            total_reward += reward
-            done = terminated or truncated
+            action = policy.predict(tensordict)
+            next_tensordict = env.step(action).get("next")
+            total_reward += next_tensordict["reward"]
+            done = next_tensordict["done"]
+            tensordict = next_tensordict
     return total_reward / n_eval_episodes