Compare commits

...

3 Commits

4 changed files with 54 additions and 90 deletions

View File

@@ -6,13 +6,12 @@ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.record import VideoRecorder
-from abc import ABC, abstractmethod
+from tensordict import LazyStackedTensorDict, TensorDict
+from abc import ABC
 class OnPolicy(ABC):
     def __init__(
         self,
-        policy,
         env_spec,
         loggers,
         optimizers,
@@ -21,19 +20,16 @@ class OnPolicy(ABC):
         batch_size,
         n_epochs,
         gamma,
-        gae_lambda,
         total_timesteps,
         eval_interval,
         eval_deterministic,
         entropy_coef,
         critic_coef,
         normalize_advantage,
-        clip_range=0.2,
         device=None,
         eval_episodes=10,
         env_spec_eval=None,
     ):
-        self.policy = policy
        self.env_spec = env_spec
        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
        self.loggers = loggers
@@ -43,21 +39,19 @@ class OnPolicy(ABC):
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.gamma = gamma
-        self.gae_lambda = gae_lambda
        self.total_timesteps = total_timesteps
        self.eval_interval = eval_interval
        self.eval_deterministic = eval_deterministic
        self.entropy_coef = entropy_coef
        self.critic_coef = critic_coef
        self.normalize_advantage = normalize_advantage
-        self.clip_range = clip_range
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.eval_episodes = eval_episodes
        # Create collector
        self.collector = SyncDataCollector(
            create_env_fn=lambda: self.make_env(eval=False),
-            policy=self.policy,
+            policy=self.actor,
            frames_per_batch=self.n_steps,
            total_frames=self.total_timesteps,
            device=self.device,
@@ -78,13 +72,13 @@ class OnPolicy(ABC):
        env_spec = self.env_spec_eval if eval else self.env_spec
        if isinstance(env_spec, str):
            env = gym.make(env_spec)
-            env = GymWrapper(env)
+            env = GymWrapper(env).to(self.device)
        elif callable(env_spec):
            env = env_spec()
            if isinstance(env, gym.Env):
-                env = GymWrapper(env)
+                env = GymWrapper(env).to(self.device)
        elif isinstance(env, gym.Env):
-            env = GymWrapper(env)
+            env = GymWrapper(env).to(self.device)
        else:
            raise ValueError("env_spec must be a string or a callable that returns an environment.")
        return env
@@ -92,7 +86,8 @@ class OnPolicy(ABC):
    def train_step(self, batch):
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()
-        loss = self.loss_module(batch)
+        losses = self.loss_module(batch)
+        loss = losses['loss_objective'] + losses["loss_entropy"] + losses["loss_critic"]
        loss.backward()
        for optimizer in self.optimizers.values():
            optimizer.step()
@@ -130,7 +125,7 @@ class OnPolicy(ABC):
        for _ in range(self.eval_episodes):
            with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
                td_test = eval_env.rollout(
-                    policy=self.policy,
+                    policy=self.actor,
                    auto_reset=True,
                    auto_cast_to_device=True,
                    break_when_any_done=True,

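Note on the train_step change above: ClipPPOLoss returns a TensorDict of per-component losses rather than a single scalar, so the refactored step sums the components before backpropagating. A slightly more general pattern (a sketch, not part of this diff) sums every entry with the loss_ prefix, which also covers configurations where the entropy or critic term is disabled:

    losses = self.loss_module(batch)
    # sum all scalar loss components emitted by the TorchRL loss module
    loss = sum(value for key, value in losses.items() if key.startswith("loss_"))
    loss.backward()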
View File

@@ -3,13 +3,13 @@ from torchrl.modules import ActorValueOperator, ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
-from fancy_rl.policy import Actor, Critic, SharedModule
+from fancy_rl.policy import Actor, Critic
 class PPO(OnPolicy):
     def __init__(
         self,
         env_spec,
-        loggers=None,
+        loggers=[],
         actor_hidden_sizes=[64, 64],
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
@@ -33,6 +33,7 @@ class PPO(OnPolicy):
         eval_episodes=10,
     ):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
        # Initialize environment to get observation and action space sizes
        self.env_spec = env_spec
@@ -40,21 +41,10 @@ class PPO(OnPolicy):
        obs_space = env.observation_space
        act_space = env.action_space
-        # Define the shared, actor, and critic modules
-        self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
-        self.actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
-        self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)
-        # Combine into an ActorValueOperator
-        self.ac_module = ActorValueOperator(
-            self.shared_module,
-            self.actor,
-            self.critic
-        )
-        # Define the policy as a ProbabilisticActor
-        policy = ProbabilisticActor(
-            module=self.ac_module.get_policy_operator(),
+        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.actor = ProbabilisticActor(
+            module=actor_net,
            in_keys=["loc", "scale"],
            out_keys=["action"],
            distribution_class=torch.distributions.Normal,
@@ -66,25 +56,7 @@ class PPO(OnPolicy):
            "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
        }
-        self.adv_module = GAE(
-            gamma=self.gamma,
-            lmbda=self.gae_lambda,
-            value_network=self.critic,
-            average_gae=False,
-        )
-        self.loss_module = ClipPPOLoss(
-            actor_network=self.actor,
-            critic_network=self.critic,
-            clip_epsilon=self.clip_range,
-            loss_critic_type='MSELoss',
-            entropy_coef=self.entropy_coef,
-            critic_coef=self.critic_coef,
-            normalize_advantage=self.normalize_advantage,
-        )
        super().__init__(
-            policy=policy,
            env_spec=env_spec,
            loggers=loggers,
            optimizers=optimizers,
@@ -93,15 +65,30 @@ class PPO(OnPolicy):
            batch_size=batch_size,
            n_epochs=n_epochs,
            gamma=gamma,
-            gae_lambda=gae_lambda,
            total_timesteps=total_timesteps,
            eval_interval=eval_interval,
            eval_deterministic=eval_deterministic,
            entropy_coef=entropy_coef,
            critic_coef=critic_coef,
            normalize_advantage=normalize_advantage,
-            clip_range=clip_range,
            device=device,
            env_spec_eval=env_spec_eval,
            eval_episodes=eval_episodes,
        )
+        self.adv_module = GAE(
+            gamma=self.gamma,
+            lmbda=gae_lambda,
+            value_network=self.critic,
+            average_gae=False,
+        )
+        self.loss_module = ClipPPOLoss(
+            actor_network=self.actor,
+            critic_network=self.critic,
+            clip_epsilon=clip_range,
+            loss_critic_type='l2',
+            entropy_coef=self.entropy_coef,
+            critic_coef=self.critic_coef,
+            normalize_advantage=self.normalize_advantage,
+        )

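Note on the reordering above: GAE and ClipPPOLoss are now built after super().__init__(), using the local gae_lambda and clip_range arguments instead of base-class attributes, and the switch from 'MSELoss' to 'l2' matches the identifiers TorchRL's critic loss accepts ('l1', 'l2', 'smooth_l1'). A minimal sketch of how these two modules are typically combined per update in TorchRL (variable names are illustrative, not from this diff): the advantage module runs first, under no_grad, to write the entries the PPO loss reads.

    with torch.no_grad():
        batch = self.adv_module(batch)   # writes "advantage" and "value_target" into the batch
    losses = self.loss_module(batch)     # TensorDict with loss_objective / loss_critic / loss_entropy
    loss = losses["loss_objective"] + losses["loss_critic"] + losses["loss_entropy"]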
View File

@@ -4,30 +4,8 @@ from torchrl.modules import MLP
 from tensordict.nn.distributions import NormalParamExtractor
 from fancy_rl.utils import is_discrete_space, get_space_shape
-class SharedModule(TensorDictModule):
-    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
-        if hidden_sizes:
-            shared_module = MLP(
-                in_features=get_space_shape(obs_space)[-1],
-                out_features=hidden_sizes[-1],
-                num_cells=hidden_sizes[:-1],
-                activation_class=getattr(nn, activation_fn),
-                device=device
-            )
-            out_features = hidden_sizes[-1]
-        else:
-            shared_module = nn.Identity()
-            out_features = get_space_shape(obs_space)[-1]
-        super().__init__(
-            module=shared_module,
-            in_keys=["observation"],
-            out_keys=["shared"],
-        )
-        self.out_features = out_features
 class Actor(TensorDictModule):
-    def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
+    def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device):
        act_space_shape = get_space_shape(act_space)
        if is_discrete_space(act_space):
            out_features = act_space_shape[-1]
@@ -36,7 +14,7 @@ class Actor(TensorDictModule):
        actor_module = nn.Sequential(
            MLP(
-                in_features=shared_module.out_features,
+                in_features=get_space_shape(obs_space)[-1],
                out_features=out_features,
                num_cells=hidden_sizes,
                activation_class=getattr(nn, activation_fn),
@@ -46,14 +24,14 @@ class Actor(TensorDictModule):
        ).to(device)
        super().__init__(
            module=actor_module,
-            in_keys=["shared"],
+            in_keys=["observation"],
            out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"],
        )
 class Critic(TensorDictModule):
-    def __init__(self, shared_module, hidden_sizes, activation_fn, device):
+    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
        critic_module = MLP(
-            in_features=shared_module.out_features,
+            in_features=get_space_shape(obs_space)[-1],
            out_features=1,
            num_cells=hidden_sizes,
            activation_class=getattr(nn, activation_fn),
@@ -61,6 +39,6 @@ class Critic(TensorDictModule):
        ).to(device)
        super().__init__(
            module=critic_module,
-            in_keys=["shared"],
+            in_keys=["observation"],
            out_keys=["state_value"],
        )

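With SharedModule removed, the actor and critic no longer share a trunk; each reads "observation" directly, and the actor's loc/scale output is wrapped by ProbabilisticActor in the PPO file above. A self-contained sketch of that wiring pattern (the layer sizes, observation/action dimensions, and return_log_prob flag are illustrative assumptions, not taken from this diff):

    import torch
    import torch.nn as nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from tensordict.nn.distributions import NormalParamExtractor
    from torchrl.modules import MLP, ProbabilisticActor

    obs_dim, act_dim = 3, 1                       # illustrative shapes
    # actor: observation -> (loc, scale), then wrapped into a Normal policy
    actor_net = TensorDictModule(
        nn.Sequential(
            MLP(in_features=obs_dim, out_features=2 * act_dim, num_cells=[64, 64], activation_class=nn.Tanh),
            NormalParamExtractor(),               # splits the last dimension into loc and scale
        ),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    )
    policy = ProbabilisticActor(
        module=actor_net,
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=torch.distributions.Normal,
        return_log_prob=True,
    )
    # critic: observation -> state_value, independent of the actor
    critic = TensorDictModule(
        MLP(in_features=obs_dim, out_features=1, num_cells=[64, 64], activation_class=nn.Tanh),
        in_keys=["observation"],
        out_keys=["state_value"],
    )
    td = TensorDict({"observation": torch.randn(4, obs_dim)}, batch_size=[4])
    td = policy(td)                               # adds "loc", "scale", "action", "sample_log_prob"
    td = critic(td)                               # adds "state_value"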
View File

@ -1,9 +1,3 @@
try:
import gym
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
except ImportError:
gym = None
import gymnasium import gymnasium
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
from torchrl.data.tensor_specs import ( from torchrl.data.tensor_specs import (
@ -11,16 +5,26 @@ from torchrl.data.tensor_specs import (
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
) )
try:
import gym
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
gym_available = True
except ImportError:
gym_available = False
def is_discrete_space(action_space): def is_discrete_space(action_space):
discrete_types = ( discrete_types = (
GymDiscrete, GymMultiDiscrete, GymMultiBinary,
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary, GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
) )
continuous_types = ( continuous_types = (
GymBox, GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
) )
if gym_available:
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types += (GymBox,)
if isinstance(action_space, discrete_types): if isinstance(action_space, discrete_types):
return True return True
elif isinstance(action_space, continuous_types): elif isinstance(action_space, continuous_types):
@ -29,7 +33,7 @@ def is_discrete_space(action_space):
raise ValueError(f"Unsupported action space type: {type(action_space)}") raise ValueError(f"Unsupported action space type: {type(action_space)}")
def get_space_shape(action_space): def get_space_shape(action_space):
if gym is not None: if gym_available:
discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary) discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types = (GymBox,) continuous_types = (GymBox,)
else: else:
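With the import guard in place, the helpers work whether or not the legacy gym package is installed, and the gym-specific space types are only added to the lookup tuples when the import succeeds. A quick illustrative check using gymnasium spaces only (the legacy-gym branch depends on what is installed; the False result for Box assumes the continuous branch returns False, which is not shown in this hunk):

    from gymnasium.spaces import Discrete, Box
    from fancy_rl.utils import is_discrete_space

    assert is_discrete_space(Discrete(4))                              # discrete branch
    assert not is_discrete_space(Box(low=-1.0, high=1.0, shape=(3,)))  # continuous branch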