I hate debugging tensordict weirdness

Author: Dominik Moritz Roth, 2024-06-02 13:56:54 +02:00
Commit: c7f5fcbf0f (parent: a867a74138)
3 changed files with 27 additions and 67 deletions

File 1 of 3: fancy_rl.algos.on_policy (the OnPolicy base class)

@@ -6,13 +6,12 @@ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.record import VideoRecorder
-from abc import ABC, abstractmethod
+from tensordict import LazyStackedTensorDict, TensorDict
+from abc import ABC
 
 class OnPolicy(ABC):
     def __init__(
         self,
-        policy,
         env_spec,
         loggers,
         optimizers,
@@ -21,19 +20,16 @@ class OnPolicy(ABC):
         batch_size,
         n_epochs,
         gamma,
-        gae_lambda,
         total_timesteps,
         eval_interval,
         eval_deterministic,
         entropy_coef,
         critic_coef,
         normalize_advantage,
-        clip_range=0.2,
         device=None,
         eval_episodes=10,
         env_spec_eval=None,
     ):
-        self.policy = policy
         self.env_spec = env_spec
         self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
         self.loggers = loggers
@@ -43,21 +39,19 @@ class OnPolicy(ABC):
         self.batch_size = batch_size
         self.n_epochs = n_epochs
         self.gamma = gamma
-        self.gae_lambda = gae_lambda
         self.total_timesteps = total_timesteps
         self.eval_interval = eval_interval
         self.eval_deterministic = eval_deterministic
         self.entropy_coef = entropy_coef
         self.critic_coef = critic_coef
         self.normalize_advantage = normalize_advantage
-        self.clip_range = clip_range
         self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
         self.eval_episodes = eval_episodes
 
         # Create collector
         self.collector = SyncDataCollector(
             create_env_fn=lambda: self.make_env(eval=False),
-            policy=self.policy,
+            policy=self.actor,
             frames_per_batch=self.n_steps,
             total_frames=self.total_timesteps,
             device=self.device,
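
For readers less familiar with TorchRL, the block above feeds the policy (now the ProbabilisticActor built by the PPO subclass) straight into a SyncDataCollector. A minimal, self-contained sketch of that pattern, using Pendulum-v1 and a throwaway linear policy rather than anything from this repository:

    # Illustrative stand-ins only: Pendulum-v1 and the tiny linear policy are
    # not part of fancy_rl; the collector call mirrors the one above.
    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("Pendulum-v1"),
        policy=policy,
        frames_per_batch=200,   # plays the role of n_steps above
        total_frames=1_000,     # plays the role of total_timesteps above
    )
    for batch in collector:     # each batch is a TensorDict of 200 transitions
        pass                    # train_step(batch) would run here
    collector.shutdown()

Iterating over the collector yields TensorDicts of shape [frames_per_batch], which is what train_step below consumes.
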
@@ -78,13 +72,13 @@ class OnPolicy(ABC):
         env_spec = self.env_spec_eval if eval else self.env_spec
         if isinstance(env_spec, str):
             env = gym.make(env_spec)
-            env = GymWrapper(env)
+            env = GymWrapper(env).to(self.device)
         elif callable(env_spec):
             env = env_spec()
             if isinstance(env, gym.Env):
-                env = GymWrapper(env)
+                env = GymWrapper(env).to(self.device)
         elif isinstance(env, gym.Env):
-            env = GymWrapper(env)
+            env = GymWrapper(env).to(self.device)
         else:
             raise ValueError("env_spec must be a string or a callable that returns an environment.")
         return env
@@ -92,7 +86,8 @@ class OnPolicy(ABC):
     def train_step(self, batch):
         for optimizer in self.optimizers.values():
             optimizer.zero_grad()
-        loss = self.loss_module(batch)
+        losses = self.loss_module(batch)
+        loss = losses['loss_objective'] + losses["loss_entropy"] + losses["loss_critic"]
         loss.backward()
         for optimizer in self.optimizers.values():
             optimizer.step()
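
This train_step change is the core of the commit: TorchRL loss modules such as ClipPPOLoss return a TensorDict of named loss terms rather than a single scalar, so the caller has to select and sum the terms before calling backward. A runnable sketch of that pattern with dummy values (the keys are the ones ClipPPOLoss emits; the tensors are placeholders for the real outputs):

    # Dummy terms standing in for the output of self.loss_module(batch);
    # only the summation pattern introduced above is being demonstrated.
    import torch
    from tensordict import TensorDict

    losses = TensorDict({
        "loss_objective": torch.tensor(0.12, requires_grad=True),
        "loss_entropy": torch.tensor(-0.01, requires_grad=True),
        "loss_critic": torch.tensor(0.34, requires_grad=True),
    }, batch_size=[])
    loss = losses["loss_objective"] + losses["loss_entropy"] + losses["loss_critic"]
    loss.backward()

A more defensive variant sums every entry whose key starts with "loss_", for example sum(v for k, v in losses.items() if k.startswith("loss_")), which keeps working if additional terms such as a KL penalty show up later.
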
@@ -130,7 +125,7 @@ class OnPolicy(ABC):
         for _ in range(self.eval_episodes):
             with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
                 td_test = eval_env.rollout(
-                    policy=self.policy,
+                    policy=self.actor,
                     auto_reset=True,
                     auto_cast_to_device=True,
                     break_when_any_done=True,

File 2 of 3: the PPO algorithm (class PPO, built on OnPolicy)

@@ -3,13 +3,13 @@ from torchrl.modules import ActorValueOperator, ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
-from fancy_rl.policy import Actor, Critic, SharedModule
+from fancy_rl.policy import Actor, Critic
 
 class PPO(OnPolicy):
     def __init__(
         self,
         env_spec,
-        loggers=None,
+        loggers=[],
         actor_hidden_sizes=[64, 64],
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
@@ -33,6 +33,7 @@ class PPO(OnPolicy):
         eval_episodes=10,
     ):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
 
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
@@ -40,21 +41,10 @@ class PPO(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
 
-        # Define the shared, actor, and critic modules
-        self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
-        self.actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
-        self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)
-
-        # Combine into an ActorValueOperator
-        self.ac_module = ActorValueOperator(
-            self.shared_module,
-            self.actor,
-            self.critic
-        )
-
-        # Define the policy as a ProbabilisticActor
-        policy = ProbabilisticActor(
-            module=self.ac_module.get_policy_operator(),
+        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.actor = ProbabilisticActor(
+            module=actor_net,
             in_keys=["loc", "scale"],
             out_keys=["action"],
             distribution_class=torch.distributions.Normal,
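
The refactor above drops the shared torso (SharedModule plus ActorValueOperator) and instead wraps an independent Actor network in a plain ProbabilisticActor. A self-contained sketch of that construction, with an illustrative 3-dimensional observation and 1-dimensional action standing in for the real spaces:

    # Toy sizes only; the wiring (loc/scale -> Normal -> action) matches the diff above.
    import torch
    import torch.nn as nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from tensordict.nn.distributions import NormalParamExtractor
    from torchrl.modules import ProbabilisticActor

    net = nn.Sequential(nn.Linear(3, 64), nn.Tanh(), nn.Linear(64, 2), NormalParamExtractor())
    actor_net = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
    actor = ProbabilisticActor(
        module=actor_net,
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=torch.distributions.Normal,
        return_log_prob=True,   # PPO-style losses need the sampling log-probability
    )
    td = actor(TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4]))
    print(td["action"].shape)   # torch.Size([4, 1])
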
@@ -67,7 +57,6 @@ class PPO(OnPolicy):
         }
 
         super().__init__(
-            policy=policy,
             env_spec=env_spec,
             loggers=loggers,
             optimizers=optimizers,
@@ -76,14 +65,12 @@ class PPO(OnPolicy):
             batch_size=batch_size,
             n_epochs=n_epochs,
             gamma=gamma,
-            gae_lambda=gae_lambda,
             total_timesteps=total_timesteps,
             eval_interval=eval_interval,
             eval_deterministic=eval_deterministic,
             entropy_coef=entropy_coef,
             critic_coef=critic_coef,
             normalize_advantage=normalize_advantage,
-            clip_range=clip_range,
             device=device,
             env_spec_eval=env_spec_eval,
             eval_episodes=eval_episodes,
@@ -91,7 +78,7 @@ class PPO(OnPolicy):
         self.adv_module = GAE(
             gamma=self.gamma,
-            lmbda=self.gae_lambda,
+            lmbda=gae_lambda,
             value_network=self.critic,
             average_gae=False,
         )
@@ -99,9 +86,9 @@ class PPO(OnPolicy):
         self.loss_module = ClipPPOLoss(
             actor_network=self.actor,
             critic_network=self.critic,
-            clip_epsilon=self.clip_range,
-            loss_critic_type='MSELoss',
+            clip_epsilon=clip_range,
+            loss_critic_type='l2',
             entropy_coef=self.entropy_coef,
             critic_coef=self.critic_coef,
             normalize_advantage=self.normalize_advantage,
         )
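
Two details in this block are easy to miss. First, ClipPPOLoss accepts 'l1', 'l2' or 'smooth_l1' for loss_critic_type, which is why the invalid 'MSELoss' string is replaced with 'l2'. Second, the GAE module has to run on each collected batch before the loss, because the loss reads the "advantage" and "value_target" entries that GAE writes. A schematic of that per-batch ordering, reusing the attribute names from this file (not a standalone script):

    # Schematic only: adv_module, loss_module and batch are the objects built
    # in this commit, shown here just to make the required ordering explicit.
    with torch.no_grad():
        adv_module(batch)            # annotates the batch with "advantage" and "value_target"
    losses = loss_module(batch)      # consumes them along with "action" and the log-probabilities
    loss = losses["loss_objective"] + losses["loss_entropy"] + losses["loss_critic"]
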

File 3 of 3: fancy_rl.policy (Actor and Critic)

@@ -4,30 +4,8 @@ from torchrl.modules import MLP
 from tensordict.nn.distributions import NormalParamExtractor
 from fancy_rl.utils import is_discrete_space, get_space_shape
 
-class SharedModule(TensorDictModule):
-    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
-        if hidden_sizes:
-            shared_module = MLP(
-                in_features=get_space_shape(obs_space)[-1],
-                out_features=hidden_sizes[-1],
-                num_cells=hidden_sizes[:-1],
-                activation_class=getattr(nn, activation_fn),
-                device=device
-            )
-            out_features = hidden_sizes[-1]
-        else:
-            shared_module = nn.Identity()
-            out_features = get_space_shape(obs_space)[-1]
-        super().__init__(
-            module=shared_module,
-            in_keys=["observation"],
-            out_keys=["shared"],
-        )
-        self.out_features = out_features
-
 class Actor(TensorDictModule):
-    def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
+    def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device):
         act_space_shape = get_space_shape(act_space)
         if is_discrete_space(act_space):
             out_features = act_space_shape[-1]
@@ -36,7 +14,7 @@ class Actor(TensorDictModule):
         actor_module = nn.Sequential(
             MLP(
-                in_features=shared_module.out_features,
+                in_features=get_space_shape(obs_space)[-1],
                 out_features=out_features,
                 num_cells=hidden_sizes,
                 activation_class=getattr(nn, activation_fn),
@@ -46,14 +24,14 @@ class Actor(TensorDictModule):
         ).to(device)
         super().__init__(
             module=actor_module,
-            in_keys=["shared"],
+            in_keys=["observation"],
             out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"],
         )
 
 class Critic(TensorDictModule):
-    def __init__(self, shared_module, hidden_sizes, activation_fn, device):
+    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
         critic_module = MLP(
-            in_features=shared_module.out_features,
+            in_features=get_space_shape(obs_space)[-1],
             out_features=1,
             num_cells=hidden_sizes,
             activation_class=getattr(nn, activation_fn),
@@ -61,6 +39,6 @@ class Critic(TensorDictModule):
         ).to(device)
         super().__init__(
             module=critic_module,
-            in_keys=["shared"],
+            in_keys=["observation"],
             out_keys=["state_value"],
         )
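
With the shared torso gone, both modules read the raw "observation" key directly instead of an intermediate "shared" key. A self-contained illustration of the TensorDictModule pattern the refactor keeps (the 3-dimensional observation and layer sizes are made up for the example):

    # Toy critic in the same style as the Critic class above.
    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.modules import MLP

    critic = TensorDictModule(
        MLP(in_features=3, out_features=1, num_cells=[64, 64]),
        in_keys=["observation"],    # read straight from the environment output
        out_keys=["state_value"],   # the value key TorchRL's GAE and ClipPPOLoss expect by default
    )
    td = TensorDict({"observation": torch.randn(8, 3)}, batch_size=[8])
    critic(td)
    print(td["state_value"].shape)  # torch.Size([8, 1])
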