I hate debugging tensordict weirdness

Dominik Moritz Roth 2024-06-02 13:56:54 +02:00
parent a867a74138
commit c7f5fcbf0f
3 changed files with 27 additions and 67 deletions

File 1 of 3: fancy_rl.algos.on_policy (class OnPolicy)

@@ -6,13 +6,12 @@ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.record import VideoRecorder
from abc import ABC, abstractmethod
from tensordict import LazyStackedTensorDict, TensorDict
from abc import ABC
class OnPolicy(ABC):
def __init__(
self,
policy,
env_spec,
loggers,
optimizers,
@@ -21,19 +20,16 @@ class OnPolicy(ABC):
batch_size,
n_epochs,
gamma,
gae_lambda,
total_timesteps,
eval_interval,
eval_deterministic,
entropy_coef,
critic_coef,
normalize_advantage,
clip_range=0.2,
device=None,
eval_episodes=10,
env_spec_eval=None,
):
self.policy = policy
self.env_spec = env_spec
self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
self.loggers = loggers
@@ -43,21 +39,19 @@ class OnPolicy(ABC):
self.batch_size = batch_size
self.n_epochs = n_epochs
self.gamma = gamma
self.gae_lambda = gae_lambda
self.total_timesteps = total_timesteps
self.eval_interval = eval_interval
self.eval_deterministic = eval_deterministic
self.entropy_coef = entropy_coef
self.critic_coef = critic_coef
self.normalize_advantage = normalize_advantage
self.clip_range = clip_range
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
self.eval_episodes = eval_episodes
# Create collector
self.collector = SyncDataCollector(
create_env_fn=lambda: self.make_env(eval=False),
policy=self.policy,
policy=self.actor,
frames_per_batch=self.n_steps,
total_frames=self.total_timesteps,
device=self.device,
@@ -78,13 +72,13 @@ class OnPolicy(ABC):
env_spec = self.env_spec_eval if eval else self.env_spec
if isinstance(env_spec, str):
env = gym.make(env_spec)
env = GymWrapper(env)
env = GymWrapper(env).to(self.device)
elif callable(env_spec):
env = env_spec()
if isinstance(env, gym.Env):
env = GymWrapper(env)
env = GymWrapper(env).to(self.device)
elif isinstance(env, gym.Env):
env = GymWrapper(env)
env = GymWrapper(env).to(self.device)
else:
raise ValueError("env_spec must be a string or a callable that returns an environment.")
return env
@@ -92,7 +86,8 @@ class OnPolicy(ABC):
def train_step(self, batch):
for optimizer in self.optimizers.values():
optimizer.zero_grad()
loss = self.loss_module(batch)
losses = self.loss_module(batch)
loss = losses['loss_objective'] + losses["loss_entropy"] + losses["loss_critic"]
loss.backward()
for optimizer in self.optimizers.values():
optimizer.step()
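
For context on the train_step change above: TorchRL's ClipPPOLoss returns a TensorDict of separate loss terms ("loss_objective", "loss_critic", "loss_entropy") rather than a single scalar, so the terms have to be combined before backward(). A small sketch of the same aggregation done generically; total_loss is a hypothetical helper, not part of the repo:

```python
def total_loss(losses):
    """Sum every loss term emitted by ClipPPOLoss (keys prefixed with "loss_")."""
    return sum(value for key, value in losses.items() if key.startswith("loss_"))

# Equivalent, inside the training step, to the explicit three-term sum above:
#   losses = self.loss_module(batch)
#   total_loss(losses).backward()
```
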
@@ -130,7 +125,7 @@ class OnPolicy(ABC):
for _ in range(self.eval_episodes):
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
td_test = eval_env.rollout(
policy=self.policy,
policy=self.actor,
auto_reset=True,
auto_cast_to_device=True,
break_when_any_done=True,
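
The evaluation hunk above runs rollouts with exploration set to the distribution mode, i.e. deterministically. A minimal sketch of that loop under stated assumptions: eval_env is a TorchRL environment, actor is the ProbabilisticActor, and the max_steps cap is illustrative since the repo's value is not shown in the hunk:

```python
import torch
from torchrl.envs import ExplorationType, set_exploration_type

def evaluate(eval_env, actor, episodes=10):
    """Average undiscounted return over a few deterministic episodes."""
    returns = []
    for _ in range(episodes):
        with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
            td = eval_env.rollout(
                max_steps=10_000,          # illustrative cap, not from the repo
                policy=actor,
                auto_reset=True,
                auto_cast_to_device=True,
                break_when_any_done=True,
            )
        returns.append(td["next", "reward"].sum().item())
    return sum(returns) / len(returns)
```
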

File 2 of 3: the PPO algorithm (class PPO)

@@ -3,13 +3,13 @@ from torchrl.modules import ActorValueOperator, ProbabilisticActor
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic, SharedModule
from fancy_rl.policy import Actor, Critic
class PPO(OnPolicy):
def __init__(
self,
env_spec,
loggers=None,
loggers=[],
actor_hidden_sizes=[64, 64],
critic_hidden_sizes=[64, 64],
actor_activation_fn="Tanh",
@@ -33,6 +33,7 @@ class PPO(OnPolicy):
eval_episodes=10,
):
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
# Initialize environment to get observation and action space sizes
self.env_spec = env_spec
@@ -40,21 +41,10 @@ class PPO(OnPolicy):
obs_space = env.observation_space
act_space = env.action_space
# Define the shared, actor, and critic modules
self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
self.actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)
# Combine into an ActorValueOperator
self.ac_module = ActorValueOperator(
self.shared_module,
self.actor,
self.critic
)
# Define the policy as a ProbabilisticActor
policy = ProbabilisticActor(
module=self.ac_module.get_policy_operator(),
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
self.actor = ProbabilisticActor(
module=actor_net,
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=torch.distributions.Normal,
@@ -67,7 +57,6 @@ class PPO(OnPolicy):
}
super().__init__(
policy=policy,
env_spec=env_spec,
loggers=loggers,
optimizers=optimizers,
@@ -76,14 +65,12 @@ class PPO(OnPolicy):
batch_size=batch_size,
n_epochs=n_epochs,
gamma=gamma,
gae_lambda=gae_lambda,
total_timesteps=total_timesteps,
eval_interval=eval_interval,
eval_deterministic=eval_deterministic,
entropy_coef=entropy_coef,
critic_coef=critic_coef,
normalize_advantage=normalize_advantage,
clip_range=clip_range,
device=device,
env_spec_eval=env_spec_eval,
eval_episodes=eval_episodes,
@@ -91,7 +78,7 @@ class PPO(OnPolicy):
self.adv_module = GAE(
gamma=self.gamma,
lmbda=self.gae_lambda,
lmbda=gae_lambda,
value_network=self.critic,
average_gae=False,
)
@@ -99,8 +86,8 @@ class PPO(OnPolicy):
self.loss_module = ClipPPOLoss(
actor_network=self.actor,
critic_network=self.critic,
clip_epsilon=self.clip_range,
loss_critic_type='MSELoss',
clip_epsilon=clip_range,
loss_critic_type='l2',
entropy_coef=self.entropy_coef,
critic_coef=self.critic_coef,
normalize_advantage=self.normalize_advantage,
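
Taken together, the PPO changes drop the shared trunk and wire a standalone actor and critic straight into GAE and ClipPPOLoss. A self-contained sketch of that wiring with plain TorchRL building blocks; the observation/action sizes, hidden sizes, and coefficients are illustrative, and the keyword names follow the ones used in the diff above:

```python
import torch
import torch.nn as nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import MLP, ProbabilisticActor
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE

obs_dim, act_dim = 4, 2  # illustrative sizes

# Actor: "observation" -> ("loc", "scale") for a Normal policy, no shared stem.
actor_net = TensorDictModule(
    nn.Sequential(
        MLP(in_features=obs_dim, out_features=2 * act_dim, num_cells=[64, 64]),
        NormalParamExtractor(),
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticActor(
    module=actor_net,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=torch.distributions.Normal,
    return_log_prob=True,
)

# Critic: "observation" -> "state_value", as GAE and ClipPPOLoss expect.
critic = TensorDictModule(
    MLP(in_features=obs_dim, out_features=1, num_cells=[64, 64]),
    in_keys=["observation"],
    out_keys=["state_value"],
)

adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic, average_gae=False)
loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    clip_epsilon=0.2,
    loss_critic_type="l2",
    entropy_coef=0.01,
    critic_coef=0.5,
    normalize_advantage=True,
)
```

In use, adv_module(batch) is typically called on the collected TensorDict first to write "advantage" and "value_target", and the result is then passed to loss_module as in train_step above.
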

File 3 of 3: fancy_rl.policy (Actor and Critic; SharedModule removed)

@@ -4,30 +4,8 @@ from torchrl.modules import MLP
from tensordict.nn.distributions import NormalParamExtractor
from fancy_rl.utils import is_discrete_space, get_space_shape
class SharedModule(TensorDictModule):
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
if hidden_sizes:
shared_module = MLP(
in_features=get_space_shape(obs_space)[-1],
out_features=hidden_sizes[-1],
num_cells=hidden_sizes[:-1],
activation_class=getattr(nn, activation_fn),
device=device
)
out_features = hidden_sizes[-1]
else:
shared_module = nn.Identity()
out_features = get_space_shape(obs_space)[-1]
super().__init__(
module=shared_module,
in_keys=["observation"],
out_keys=["shared"],
)
self.out_features = out_features
class Actor(TensorDictModule):
def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device):
act_space_shape = get_space_shape(act_space)
if is_discrete_space(act_space):
out_features = act_space_shape[-1]
@@ -36,7 +14,7 @@ class Actor(TensorDictModule):
actor_module = nn.Sequential(
MLP(
in_features=shared_module.out_features,
in_features=get_space_shape(obs_space)[-1],
out_features=out_features,
num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn),
@@ -46,14 +24,14 @@ class Actor(TensorDictModule):
).to(device)
super().__init__(
module=actor_module,
in_keys=["shared"],
in_keys=["observation"],
out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"],
)
class Critic(TensorDictModule):
def __init__(self, shared_module, hidden_sizes, activation_fn, device):
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
critic_module = MLP(
in_features=shared_module.out_features,
in_features=get_space_shape(obs_space)[-1],
out_features=1,
num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn),
@@ -61,6 +39,6 @@ class Critic(TensorDictModule):
).to(device)
super().__init__(
module=critic_module,
in_keys=["shared"],
in_keys=["observation"],
out_keys=["state_value"],
)
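
After this change, Actor and Critic are plain TensorDictModules that read "observation" directly instead of a precomputed "shared" embedding. A short usage sketch under stated assumptions: the constructor signatures come from the diff above, while the gymnasium Box spaces, sizes, and CPU device are illustrative:

```python
import gymnasium as gym
import torch
from tensordict import TensorDict
from fancy_rl.policy import Actor, Critic

# Illustrative continuous spaces; in the algorithms these come from the wrapped env.
obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))

actor = Actor(obs_space, act_space, hidden_sizes=[64, 64], activation_fn="Tanh", device="cpu")
critic = Critic(obs_space, hidden_sizes=[64, 64], activation_fn="Tanh", device="cpu")

# Both modules read "observation" straight from the TensorDict.
td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
actor(td)   # writes "loc" and "scale" for a continuous action space
critic(td)  # writes "state_value"
print(td["loc"].shape, td["scale"].shape, td["state_value"].shape)
```
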