Rework algo impls

This commit is contained in:
Dominik Moritz Roth 2024-10-21 15:23:39 +02:00
parent 651ef1522f
commit 0c6e58634f
4 changed files with 105 additions and 50 deletions

View File

@ -3,6 +3,7 @@ import gymnasium as gym
from torchrl.envs.libs.gym import GymWrapper
from torchrl.record import VideoRecorder
from abc import ABC
from tensordict import TensorDict
from fancy_rl.loggers import TerminalLogger
@ -53,12 +54,11 @@ class Algo(ABC):
env = GymWrapper(env).to(self.device)
elif callable(env_spec):
env = env_spec()
if isinstance(env, gym.Env):
env = GymWrapper(env).to(self.device)
elif isinstance(env, gym.Env):
if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
env = GymWrapper(env).to(self.device)
else:
raise ValueError("env_spec must be a string or a callable that returns an environment.")
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
return env
def train_step(self, batch):
@ -70,6 +70,20 @@ class Algo(ABC):
def evaluate(self, epoch):
raise NotImplementedError("evaluate method must be implemented in subclass.")
def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()
def predict(
self,
observation,
state=None,
deterministic=False
):
with torch.no_grad():
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
td = TensorDict({"observation": obs_tensor}, batch_size=[1])
action_td = self.prob_actor(td)
action = action_td["action"]
# We're not using recurrent policies, so we'll always return None for the state
next_state = None
return action.squeeze(0).cpu().numpy(), next_state

View File

@ -55,7 +55,7 @@ class OnPolicy(Algo):
# Create collector
self.collector = SyncDataCollector(
create_env_fn=lambda: self.make_env(eval=False),
policy=self.actor,
policy=self.prob_actor,
frames_per_batch=self.n_steps,
total_frames=self.total_timesteps,
device=self.device,

View File

@ -4,7 +4,7 @@ from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection # Updated import
from fancy_rl.utils import is_discrete_space
class PPO(OnPolicy):
def __init__(
@ -31,7 +31,10 @@ class PPO(OnPolicy):
device=None,
env_spec_eval=None,
eval_episodes=10,
full_covariance=False,
):
self.clip_range = clip_range
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
@ -41,15 +44,29 @@ class PPO(OnPolicy):
obs_space = env.observation_space
act_space = env.action_space
self.discrete = is_discrete_space(act_space)
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
self.actor = ProbabilisticActor(
module=actor_net,
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=torch.distributions.Normal,
return_log_prob=True
)
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
if self.discrete:
distribution_class = torch.distributions.Categorical
distribution_kwargs = {"logits": "action_logits"}
else:
if full_covariance:
distribution_class = torch.distributions.MultivariateNormal
in_keys = ["loc", "scale_tril"]
else:
distribution_class = torch.distributions.Normal
in_keys = ["loc", "scale"]
self.prob_actor = ProbabilisticActor(
module=self.actor,
distribution_class=distribution_class,
return_log_prob=True,
in_keys=in_keys,
out_keys=["action"]
)
optimizers = {
"actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
@ -86,7 +103,7 @@ class PPO(OnPolicy):
self.loss_module = ClipPPOLoss(
actor_network=self.actor,
critic_network=self.critic,
clip_epsilon=clip_range,
clip_epsilon=self.clip_range,
loss_critic_type='l2',
entropy_coef=self.entropy_coef,
critic_coef=self.critic_coef,

View File

@ -10,7 +10,36 @@ from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection, BaseProjection
from fancy_rl.objectives import TRPLLoss
from fancy_rl.utils import is_discrete_space
from copy import deepcopy
from tensordict.nn import TensorDictModule
from tensordict import TensorDict
class ProjectedActor(TensorDictModule):
def __init__(self, raw_actor, old_actor, projection):
combined_module = self.CombinedModule(raw_actor, old_actor, projection)
super().__init__(
module=combined_module,
in_keys=raw_actor.in_keys,
out_keys=raw_actor.out_keys
)
self.raw_actor = raw_actor
self.old_actor = old_actor
self.projection = projection
class CombinedModule(nn.Module):
def __init__(self, raw_actor, old_actor, projection):
super().__init__()
self.raw_actor = raw_actor
self.old_actor = old_actor
self.projection = projection
def forward(self, tensordict):
raw_params = self.raw_actor(tensordict)
old_params = self.old_actor(tensordict)
combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
projected_params = self.projection(combined_params)
return projected_params
class TRPL(OnPolicy):
def __init__(
@ -40,6 +69,7 @@ class TRPL(OnPolicy):
device=None,
env_spec_eval=None,
eval_episodes=10,
full_covariance=False,
):
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
@ -50,8 +80,11 @@ class TRPL(OnPolicy):
obs_space = env.observation_space
act_space = env.action_space
assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
# Handle projection_class
if isinstance(projection_class, str):
@ -60,20 +93,27 @@ class TRPL(OnPolicy):
raise ValueError("projection_class must be a string or a subclass of BaseProjection")
self.projection = projection_class(
in_keys=["loc", "scale"],
out_keys=["loc", "scale"],
trust_region_bound_mean=trust_region_bound_mean,
trust_region_bound_cov=trust_region_bound_cov
in_keys=["loc", "scale_tril", "old_loc", "old_scale_tril"] if full_covariance else ["loc", "scale", "old_loc", "old_scale"],
out_keys=["loc", "scale_tril"] if full_covariance else ["loc", "scale"],
mean_bound=trust_region_bound_mean,
cov_bound=trust_region_bound_cov
)
self.actor = ProbabilisticActor(
module=actor_net,
in_keys=["observation"],
out_keys=["loc", "scale"],
distribution_class=torch.distributions.Normal,
return_log_prob=True
self.actor = ProjectedActor(self.raw_actor, self.old_actor, self.projection)
if full_covariance:
distribution_class = torch.distributions.MultivariateNormal
distribution_kwargs = {"loc": "loc", "scale_tril": "scale_tril"}
else:
distribution_class = torch.distributions.Normal
distribution_kwargs = {"loc": "loc", "scale": "scale"}
self.prob_actor = ProbabilisticActor(
module=self.actor,
distribution_class=distribution_class,
return_log_prob=True,
in_keys=distribution_kwargs,
)
self.old_actor = deepcopy(self.actor)
self.trust_region_coef = trust_region_coef
self.loss_module = TRPLLoss(
@ -88,7 +128,7 @@ class TRPL(OnPolicy):
)
optimizers = {
"actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
"actor": torch.optim.Adam(self.raw_actor.parameters(), lr=learning_rate),
"critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
}
@ -119,23 +159,7 @@ class TRPL(OnPolicy):
)
def update_old_policy(self):
self.old_actor.load_state_dict(self.actor.state_dict())
def project_policy(self, obs):
with torch.no_grad():
old_dist = self.old_actor(obs)
new_dist = self.actor(obs)
projected_params = self.projection.project(new_dist, old_dist)
return projected_params
def pre_update(self, tensordict):
obs = tensordict["observation"]
projected_dist = self.project_policy(obs)
# Update tensordict with projected distribution parameters
tensordict["projected_loc"] = projected_dist[0]
tensordict["projected_scale"] = projected_dist[1]
return tensordict
self.old_actor.load_state_dict(self.raw_actor.state_dict())
def post_update(self):
self.update_old_policy()