Rework algo impls

Dominik Moritz Roth 2024-10-21 15:23:39 +02:00
parent 651ef1522f
commit 0c6e58634f
4 changed files with 105 additions and 50 deletions

View File

@@ -3,6 +3,7 @@ import gymnasium as gym
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.record import VideoRecorder
 from abc import ABC
+from tensordict import TensorDict
 from fancy_rl.loggers import TerminalLogger
@@ -53,12 +54,11 @@ class Algo(ABC):
             env = GymWrapper(env).to(self.device)
         elif callable(env_spec):
             env = env_spec()
-            if isinstance(env, gym.Env):
-                env = GymWrapper(env).to(self.device)
+            if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
+                raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
+            elif isinstance(env, gym.Env):
+                env = GymWrapper(env).to(self.device)
         else:
-            raise ValueError("env_spec must be a string or a callable that returns an environment.")
+            raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
         return env
 
     def train_step(self, batch):
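
For reference, the make_env contract after this change accepts either a Gym ID string or a callable that returns an (optionally already wrapped) environment. A minimal usage sketch, assuming the algorithms expose an env_spec constructor argument as the error messages suggest (the exact constructor signature and import path are not part of this diff):

    import gymnasium as gym
    from fancy_rl import PPO  # import path assumed

    ppo_from_id = PPO(env_spec="Pendulum-v1")                    # string: resolved internally
    ppo_from_fn = PPO(env_spec=lambda: gym.make("Pendulum-v1"))  # callable: must return a gym.Env or wrapper
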
@@ -70,6 +70,20 @@ class Algo(ABC):
     def evaluate(self, epoch):
         raise NotImplementedError("evaluate method must be implemented in subclass.")
 
-def dump_video(module):
-    if isinstance(module, VideoRecorder):
-        module.dump()
+    def predict(
+        self,
+        observation,
+        state=None,
+        deterministic=False
+    ):
+        with torch.no_grad():
+            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
+            td = TensorDict({"observation": obs_tensor}, batch_size=[1])
+            action_td = self.prob_actor(td)
+            action = action_td["action"]
+            # We're not using recurrent policies, so we'll always return None for the state
+            next_state = None
+            return action.squeeze(0).cpu().numpy(), next_state
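
The new predict() gives the base class an SB3-style inference entry point: it wraps a single observation in a TensorDict, runs the probabilistic actor, and returns a NumPy action plus an (always None) recurrent state. A minimal usage sketch, assuming a trained instance named algo (env name illustrative):

    import gymnasium as gym

    env = gym.make("Pendulum-v1")
    obs, _ = env.reset()
    action, _state = algo.predict(obs)
    obs, reward, terminated, truncated, info = env.step(action)
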

View File

@@ -55,7 +55,7 @@ class OnPolicy(Algo):
         # Create collector
         self.collector = SyncDataCollector(
             create_env_fn=lambda: self.make_env(eval=False),
-            policy=self.actor,
+            policy=self.prob_actor,
             frames_per_batch=self.n_steps,
             total_frames=self.total_timesteps,
             device=self.device,
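
Pointing the collector at self.prob_actor (rather than the raw parameter network) matters because SyncDataCollector calls the policy during rollout to sample actions and log-probabilities. A rough sketch of how such a collector is consumed, following standard torchrl conventions (the actual fancy_rl training loop is not shown in this diff):

    # each iteration yields a TensorDict of n_steps transitions, including
    # "observation", "action" and the log-probs written by the probabilistic actor
    for batch in self.collector:
        self.train_step(batch)
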

View File

@@ -4,7 +4,7 @@ from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
-from fancy_rl.projections import get_projection  # Updated import
+from fancy_rl.utils import is_discrete_space
 
 class PPO(OnPolicy):
     def __init__(
@@ -31,7 +31,10 @@ class PPO(OnPolicy):
         device=None,
         env_spec_eval=None,
         eval_episodes=10,
+        full_covariance=False,
     ):
+        self.clip_range = clip_range
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
@@ -41,15 +44,29 @@ class PPO(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
+        self.discrete = is_discrete_space(act_space)
 
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
-        self.actor = ProbabilisticActor(
-            module=actor_net,
-            in_keys=["loc", "scale"],
-            out_keys=["action"],
-            distribution_class=torch.distributions.Normal,
-            return_log_prob=True
-        )
+        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+
+        if self.discrete:
+            distribution_class = torch.distributions.Categorical
+            distribution_kwargs = {"logits": "action_logits"}
+        else:
+            if full_covariance:
+                distribution_class = torch.distributions.MultivariateNormal
+                in_keys = ["loc", "scale_tril"]
+            else:
+                distribution_class = torch.distributions.Normal
+                in_keys = ["loc", "scale"]
+
+        self.prob_actor = ProbabilisticActor(
+            module=self.actor,
+            distribution_class=distribution_class,
+            return_log_prob=True,
+            in_keys=in_keys,
+            out_keys=["action"]
+        )
 
         optimizers = {
             "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
@@ -86,9 +103,9 @@ class PPO(OnPolicy):
         self.loss_module = ClipPPOLoss(
             actor_network=self.actor,
             critic_network=self.critic,
-            clip_epsilon=clip_range,
+            clip_epsilon=self.clip_range,
             loss_critic_type='l2',
             entropy_coef=self.entropy_coef,
             critic_coef=self.critic_coef,
             normalize_advantage=self.normalize_advantage,
         )
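
is_discrete_space is imported from fancy_rl.utils, but its implementation is not part of this diff. A plausible minimal version, shown here only to illustrate what the PPO and TRPL branches rely on, would check against the discrete Gymnasium space types:

    import gymnasium as gym

    def is_discrete_space(space):
        # hypothetical sketch; the real fancy_rl.utils helper may differ
        return isinstance(space, (gym.spaces.Discrete,
                                  gym.spaces.MultiDiscrete,
                                  gym.spaces.MultiBinary))
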

View File

@@ -10,7 +10,36 @@ from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
 from fancy_rl.projections import get_projection, BaseProjection
 from fancy_rl.objectives import TRPLLoss
+from fancy_rl.utils import is_discrete_space
 from copy import deepcopy
+from tensordict.nn import TensorDictModule
+from tensordict import TensorDict
+
+class ProjectedActor(TensorDictModule):
+    def __init__(self, raw_actor, old_actor, projection):
+        combined_module = self.CombinedModule(raw_actor, old_actor, projection)
+        super().__init__(
+            module=combined_module,
+            in_keys=raw_actor.in_keys,
+            out_keys=raw_actor.out_keys
+        )
+        self.raw_actor = raw_actor
+        self.old_actor = old_actor
+        self.projection = projection
+
+    class CombinedModule(nn.Module):
+        def __init__(self, raw_actor, old_actor, projection):
+            super().__init__()
+            self.raw_actor = raw_actor
+            self.old_actor = old_actor
+            self.projection = projection
+
+        def forward(self, tensordict):
+            raw_params = self.raw_actor(tensordict)
+            old_params = self.old_actor(tensordict)
+            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
+            projected_params = self.projection(combined_params)
+            return projected_params
 
 class TRPL(OnPolicy):
     def __init__(
@@ -40,6 +69,7 @@ class TRPL(OnPolicy):
         device=None,
         env_spec_eval=None,
         eval_episodes=10,
+        full_covariance=False,
     ):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
@@ -50,8 +80,11 @@
         obs_space = env.observation_space
         act_space = env.action_space
+        assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
 
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
 
         # Handle projection_class
         if isinstance(projection_class, str):
@@ -60,20 +93,27 @@
             raise ValueError("projection_class must be a string or a subclass of BaseProjection")
 
         self.projection = projection_class(
-            in_keys=["loc", "scale"],
-            out_keys=["loc", "scale"],
-            trust_region_bound_mean=trust_region_bound_mean,
-            trust_region_bound_cov=trust_region_bound_cov
+            in_keys=["loc", "scale_tril", "old_loc", "old_scale_tril"] if full_covariance else ["loc", "scale", "old_loc", "old_scale"],
+            out_keys=["loc", "scale_tril"] if full_covariance else ["loc", "scale"],
+            mean_bound=trust_region_bound_mean,
+            cov_bound=trust_region_bound_cov
         )
 
-        self.actor = ProbabilisticActor(
-            module=actor_net,
-            in_keys=["observation"],
-            out_keys=["loc", "scale"],
-            distribution_class=torch.distributions.Normal,
-            return_log_prob=True
-        )
-        self.old_actor = deepcopy(self.actor)
+        self.actor = ProjectedActor(self.raw_actor, self.old_actor, self.projection)
+
+        if full_covariance:
+            distribution_class = torch.distributions.MultivariateNormal
+            distribution_kwargs = {"loc": "loc", "scale_tril": "scale_tril"}
+        else:
+            distribution_class = torch.distributions.Normal
+            distribution_kwargs = {"loc": "loc", "scale": "scale"}
+
+        self.prob_actor = ProbabilisticActor(
+            module=self.actor,
+            distribution_class=distribution_class,
+            return_log_prob=True,
+            in_keys=distribution_kwargs,
+        )
 
         self.trust_region_coef = trust_region_coef
         self.loss_module = TRPLLoss(
@@ -88,7 +128,7 @@
         )
 
         optimizers = {
-            "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
+            "actor": torch.optim.Adam(self.raw_actor.parameters(), lr=learning_rate),
             "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
         }
@@ -119,23 +159,7 @@
         )
 
     def update_old_policy(self):
-        self.old_actor.load_state_dict(self.actor.state_dict())
-
-    def project_policy(self, obs):
-        with torch.no_grad():
-            old_dist = self.old_actor(obs)
-        new_dist = self.actor(obs)
-        projected_params = self.projection.project(new_dist, old_dist)
-        return projected_params
-
-    def pre_update(self, tensordict):
-        obs = tensordict["observation"]
-        projected_dist = self.project_policy(obs)
-        # Update tensordict with projected distribution parameters
-        tensordict["projected_loc"] = projected_dist[0]
-        tensordict["projected_scale"] = projected_dist[1]
-        return tensordict
+        self.old_actor.load_state_dict(self.raw_actor.state_dict())
 
     def post_update(self):
         self.update_old_policy()
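
ProjectedActor evaluates both the current (raw) and the old policy heads on the same observations, prefixes the old parameters with "old_", and hands the merged TensorDict to the projection, which therefore has to accept the old_* in_keys configured above and write projected loc/scale (or scale_tril) out_keys. The toy module below illustrates that interface only: it clamps the mean shift to mean_bound and passes the scale through unchanged, and is not the library's actual projection implementation.

    import torch
    import torch.nn as nn
    from tensordict import TensorDict

    class ClampMeanProjection(nn.Module):
        """Toy stand-in for a BaseProjection-style module (illustrative only)."""
        def __init__(self, in_keys, out_keys, mean_bound, cov_bound):
            super().__init__()
            self.in_keys, self.out_keys = in_keys, out_keys
            self.mean_bound, self.cov_bound = mean_bound, cov_bound  # cov_bound unused in this toy

        def forward(self, td: TensorDict) -> TensorDict:
            # keep the new mean inside an L2 ball of radius mean_bound around the old mean
            delta = td["loc"] - td["old_loc"]
            norm = delta.norm(dim=-1, keepdim=True).clamp_min(1e-8)
            td["loc"] = td["old_loc"] + delta * torch.clamp(self.mean_bound / norm, max=1.0)
            return td
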