I hate debugging tensordict weirdness
This commit is contained in:
parent a867a74138
commit c7f5fcbf0f
@@ -6,13 +6,12 @@ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.record import VideoRecorder
-from abc import ABC, abstractmethod
+from tensordict import LazyStackedTensorDict, TensorDict
+from abc import ABC
 
 class OnPolicy(ABC):
     def __init__(
         self,
-        policy,
         env_spec,
         loggers,
         optimizers,
@@ -21,19 +20,16 @@ class OnPolicy(ABC):
         batch_size,
         n_epochs,
         gamma,
-        gae_lambda,
         total_timesteps,
         eval_interval,
         eval_deterministic,
         entropy_coef,
         critic_coef,
         normalize_advantage,
-        clip_range=0.2,
         device=None,
         eval_episodes=10,
         env_spec_eval=None,
     ):
-        self.policy = policy
         self.env_spec = env_spec
         self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
         self.loggers = loggers
@@ -43,21 +39,19 @@ class OnPolicy(ABC):
         self.batch_size = batch_size
         self.n_epochs = n_epochs
         self.gamma = gamma
-        self.gae_lambda = gae_lambda
         self.total_timesteps = total_timesteps
         self.eval_interval = eval_interval
         self.eval_deterministic = eval_deterministic
         self.entropy_coef = entropy_coef
         self.critic_coef = critic_coef
         self.normalize_advantage = normalize_advantage
-        self.clip_range = clip_range
         self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
         self.eval_episodes = eval_episodes
 
         # Create collector
         self.collector = SyncDataCollector(
             create_env_fn=lambda: self.make_env(eval=False),
-            policy=self.policy,
+            policy=self.actor,
             frames_per_batch=self.n_steps,
             total_frames=self.total_timesteps,
             device=self.device,
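Note on the collector wiring above: SyncDataCollector repeatedly runs the policy in the environment and yields TensorDict batches of transitions. A minimal, self-contained sketch of that loop (the env id, frame counts and the random policy are placeholders, not values from this commit):

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("Pendulum-v1"),
        policy=None,              # None falls back to a random policy; the trainer passes its actor
        frames_per_batch=200,
        total_frames=1000,
        device="cpu",
    )
    for batch in collector:       # each batch is a TensorDict holding frames_per_batch transitions
        print(batch["next", "reward"].shape)
    collector.shutdown()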
@@ -78,13 +72,13 @@ class OnPolicy(ABC):
         env_spec = self.env_spec_eval if eval else self.env_spec
         if isinstance(env_spec, str):
             env = gym.make(env_spec)
-            env = GymWrapper(env)
+            env = GymWrapper(env).to(self.device)
         elif callable(env_spec):
             env = env_spec()
             if isinstance(env, gym.Env):
-                env = GymWrapper(env)
+                env = GymWrapper(env).to(self.device)
         elif isinstance(env, gym.Env):
-            env = GymWrapper(env)
+            env = GymWrapper(env).to(self.device)
         else:
             raise ValueError("env_spec must be a string or a callable that returns an environment.")
         return env
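The added .to(self.device) calls move the wrapped environment, its specs, and the tensordicts it produces onto the trainer's device. A small sketch of the same pattern in isolation (the env id and device choice are placeholders; swap gymnasium for gym to match whichever backend the project uses):

    import gymnasium as gym
    import torch
    from torchrl.envs.libs.gym import GymWrapper

    device = "cuda" if torch.cuda.is_available() else "cpu"
    env = GymWrapper(gym.make("Pendulum-v1")).to(device)
    td = env.reset()
    print(td["observation"].device)   # tensors now live on `device`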
@@ -92,7 +86,8 @@ class OnPolicy(ABC):
     def train_step(self, batch):
         for optimizer in self.optimizers.values():
             optimizer.zero_grad()
-        loss = self.loss_module(batch)
+        losses = self.loss_module(batch)
+        loss = losses['loss_objective'] + losses["loss_entropy"] + losses["loss_critic"]
         loss.backward()
         for optimizer in self.optimizers.values():
             optimizer.step()
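This change reflects that torchrl loss modules such as ClipPPOLoss return a TensorDict of named terms rather than a single scalar, so the scalar objective has to be assembled by hand. A minimal sketch of the aggregation pattern (a hypothetical helper, not part of this commit):

    import torch
    from tensordict import TensorDict

    def total_loss(losses: TensorDict) -> torch.Tensor:
        # Sum only the "loss_*" entries (loss_objective, loss_critic, loss_entropy);
        # diagnostic outputs such as "entropy" or "ESS" stay out of the backward pass.
        return sum(term for key, term in losses.items() if key.startswith("loss_"))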
@@ -130,7 +125,7 @@ class OnPolicy(ABC):
         for _ in range(self.eval_episodes):
             with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
                 td_test = eval_env.rollout(
-                    policy=self.policy,
+                    policy=self.actor,
                     auto_reset=True,
                     auto_cast_to_device=True,
                     break_when_any_done=True,
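set_exploration_type(ExplorationType.MODE) makes probabilistic modules return the distribution mode instead of sampling, which is what makes this rollout deterministic. A self-contained sketch of the same evaluation pattern (Pendulum-v1 and max_steps are placeholders; the trainer additionally passes policy=self.actor and auto_cast_to_device=True):

    import torch
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs import ExplorationType, set_exploration_type

    eval_env = GymEnv("Pendulum-v1")
    with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
        td_eval = eval_env.rollout(max_steps=200, break_when_any_done=True)
    episode_return = td_eval["next", "reward"].sum().item()
    print(episode_return)

The hunks below appear to come from the PPO subclass rather than the OnPolicy base class.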
@@ -3,13 +3,13 @@ from torchrl.modules import ActorValueOperator, ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
-from fancy_rl.policy import Actor, Critic, SharedModule
+from fancy_rl.policy import Actor, Critic
 
 class PPO(OnPolicy):
     def __init__(
         self,
         env_spec,
-        loggers=None,
+        loggers=[],
         actor_hidden_sizes=[64, 64],
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
@@ -33,6 +33,7 @@ class PPO(OnPolicy):
         eval_episodes=10,
     ):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
 
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
@@ -40,21 +41,10 @@ class PPO(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
 
-        # Define the shared, actor, and critic modules
-        self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
-        self.actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
-        self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)
-
-        # Combine into an ActorValueOperator
-        self.ac_module = ActorValueOperator(
-            self.shared_module,
-            self.actor,
-            self.critic
-        )
-
-        # Define the policy as a ProbabilisticActor
-        policy = ProbabilisticActor(
-            module=self.ac_module.get_policy_operator(),
+        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.actor = ProbabilisticActor(
+            module=actor_net,
             in_keys=["loc", "scale"],
             out_keys=["action"],
             distribution_class=torch.distributions.Normal,
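The new wiring drops the shared trunk and the ActorValueOperator: the Actor network writes "loc" and "scale" straight from "observation", and ProbabilisticActor turns those into a sampled "action". A self-contained sketch of that pattern with assumed sizes (3 observations, 1 action; not the repository's Actor class):

    import torch
    import torch.nn as nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from tensordict.nn.distributions import NormalParamExtractor
    from torchrl.modules import MLP, ProbabilisticActor

    obs_dim, act_dim = 3, 1
    net = nn.Sequential(
        MLP(in_features=obs_dim, out_features=2 * act_dim, num_cells=[64, 64]),
        NormalParamExtractor(),  # splits the last dim into loc and scale
    )
    actor_net = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
    actor = ProbabilisticActor(
        module=actor_net,
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=torch.distributions.Normal,
        return_log_prob=True,
    )
    td = actor(TensorDict({"observation": torch.randn(4, obs_dim)}, batch_size=[4]))
    print(td["action"].shape)  # torch.Size([4, 1])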
@@ -67,7 +57,6 @@ class PPO(OnPolicy):
         }
 
         super().__init__(
-            policy=policy,
             env_spec=env_spec,
             loggers=loggers,
             optimizers=optimizers,
@@ -76,14 +65,12 @@ class PPO(OnPolicy):
             batch_size=batch_size,
             n_epochs=n_epochs,
             gamma=gamma,
-            gae_lambda=gae_lambda,
             total_timesteps=total_timesteps,
             eval_interval=eval_interval,
             eval_deterministic=eval_deterministic,
             entropy_coef=entropy_coef,
             critic_coef=critic_coef,
             normalize_advantage=normalize_advantage,
-            clip_range=clip_range,
             device=device,
             env_spec_eval=env_spec_eval,
             eval_episodes=eval_episodes,
@@ -91,7 +78,7 @@ class PPO(OnPolicy):
 
         self.adv_module = GAE(
             gamma=self.gamma,
-            lmbda=self.gae_lambda,
+            lmbda=gae_lambda,
             value_network=self.critic,
             average_gae=False,
         )
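Since gae_lambda is no longer stored on the base class, the GAE module is now built from the local argument. For reference, a self-contained sketch of how such a GAE module is applied to a batch before the PPO update (the shapes, keys and the linear value net are assumptions, not this repo's objects):

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.objectives.value.advantages import GAE

    value_net = TensorDictModule(torch.nn.Linear(3, 1), in_keys=["observation"], out_keys=["state_value"])
    adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=value_net, average_gae=False)

    T = 5
    batch = TensorDict({
        "observation": torch.randn(T, 3),
        "next": TensorDict({
            "observation": torch.randn(T, 3),
            "reward": torch.randn(T, 1),
            "done": torch.zeros(T, 1, dtype=torch.bool),
            "terminated": torch.zeros(T, 1, dtype=torch.bool),
        }, batch_size=[T]),
    }, batch_size=[T])

    batch = adv_module(batch)   # writes "advantage" and "value_target" into the batch
    print(batch["advantage"].shape)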
@@ -99,9 +86,9 @@ class PPO(OnPolicy):
         self.loss_module = ClipPPOLoss(
             actor_network=self.actor,
             critic_network=self.critic,
-            clip_epsilon=self.clip_range,
-            loss_critic_type='MSELoss',
+            clip_epsilon=clip_range,
+            loss_critic_type='l2',
             entropy_coef=self.entropy_coef,
             critic_coef=self.critic_coef,
             normalize_advantage=self.normalize_advantage,
         )
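The loss_critic_type switch matters because torchrl's critic loss accepts the strings "l1", "l2" and "smooth_l1", so the nn class name "MSELoss" would not be recognized. A compact, self-contained sketch of the configuration (the tiny actor/critic here are stand-ins, not the repository's classes):

    import torch
    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from tensordict.nn.distributions import NormalParamExtractor
    from torchrl.modules import ProbabilisticActor
    from torchrl.objectives import ClipPPOLoss

    obs_dim, act_dim = 3, 1
    actor_net = TensorDictModule(
        nn.Sequential(nn.Linear(obs_dim, 2 * act_dim), NormalParamExtractor()),
        in_keys=["observation"], out_keys=["loc", "scale"],
    )
    actor = ProbabilisticActor(actor_net, in_keys=["loc", "scale"], out_keys=["action"],
                               distribution_class=torch.distributions.Normal, return_log_prob=True)
    critic = TensorDictModule(nn.Linear(obs_dim, 1), in_keys=["observation"], out_keys=["state_value"])

    loss_module = ClipPPOLoss(
        actor_network=actor,
        critic_network=critic,
        clip_epsilon=0.2,
        loss_critic_type="l2",     # "l1", "l2" or "smooth_l1"
        entropy_coef=0.01,
        critic_coef=1.0,
        normalize_advantage=True,
    )

The hunks below appear to come from the policy module that defines SharedModule, Actor and Critic.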
@@ -4,30 +4,8 @@ from torchrl.modules import MLP
 from tensordict.nn.distributions import NormalParamExtractor
 from fancy_rl.utils import is_discrete_space, get_space_shape
 
-class SharedModule(TensorDictModule):
-    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
-        if hidden_sizes:
-            shared_module = MLP(
-                in_features=get_space_shape(obs_space)[-1],
-                out_features=hidden_sizes[-1],
-                num_cells=hidden_sizes[:-1],
-                activation_class=getattr(nn, activation_fn),
-                device=device
-            )
-            out_features = hidden_sizes[-1]
-        else:
-            shared_module = nn.Identity()
-            out_features = get_space_shape(obs_space)[-1]
-
-        super().__init__(
-            module=shared_module,
-            in_keys=["observation"],
-            out_keys=["shared"],
-        )
-        self.out_features = out_features
-
 class Actor(TensorDictModule):
-    def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
+    def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device):
         act_space_shape = get_space_shape(act_space)
         if is_discrete_space(act_space):
             out_features = act_space_shape[-1]
@@ -36,7 +14,7 @@ class Actor(TensorDictModule):
 
         actor_module = nn.Sequential(
             MLP(
-                in_features=shared_module.out_features,
+                in_features=get_space_shape(obs_space)[-1],
                 out_features=out_features,
                 num_cells=hidden_sizes,
                 activation_class=getattr(nn, activation_fn),
@@ -46,14 +24,14 @@ class Actor(TensorDictModule):
         ).to(device)
         super().__init__(
             module=actor_module,
-            in_keys=["shared"],
+            in_keys=["observation"],
             out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"],
         )
 
 class Critic(TensorDictModule):
-    def __init__(self, shared_module, hidden_sizes, activation_fn, device):
+    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
         critic_module = MLP(
-            in_features=shared_module.out_features,
+            in_features=get_space_shape(obs_space)[-1],
             out_features=1,
             num_cells=hidden_sizes,
             activation_class=getattr(nn, activation_fn),
@@ -61,6 +39,6 @@ class Critic(TensorDictModule):
         ).to(device)
         super().__init__(
             module=critic_module,
-            in_keys=["shared"],
+            in_keys=["observation"],
             out_keys=["state_value"],
         )
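With the shared trunk gone, both Actor and Critic now read "observation" directly instead of a "shared" key. A quick check of the resulting key flow (sizes are assumptions):

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    critic = TensorDictModule(torch.nn.Linear(3, 1), in_keys=["observation"], out_keys=["state_value"])
    td = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4])
    print(critic(td)["state_value"].shape)   # torch.Size([4, 1])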