Rework algo impls
parent 651ef1522f
commit 0c6e58634f
@@ -3,6 +3,7 @@ import gymnasium as gym
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.record import VideoRecorder
 from abc import ABC
+from tensordict import TensorDict

 from fancy_rl.loggers import TerminalLogger

@@ -53,12 +54,11 @@ class Algo(ABC):
             env = GymWrapper(env).to(self.device)
         elif callable(env_spec):
             env = env_spec()
-            if isinstance(env, gym.Env):
-                env = GymWrapper(env).to(self.device)
-        elif isinstance(env, gym.Env):
+            if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
+                raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
             env = GymWrapper(env).to(self.device)
         else:
-            raise ValueError("env_spec must be a string or a callable that returns an environment.")
+            raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
         return env

     def train_step(self, batch):
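For context, the two `env_spec` forms that `make_env` accepts after this change, sketched with gymnasium only (no algo object is constructed here; `Pendulum-v1` is just an arbitrary example ID):

```python
import gymnasium as gym

# A Gym ID string, or a callable returning a gym.Env / gym.core.Wrapper instance.
env_spec_str = "Pendulum-v1"
env_spec_fn = lambda: gym.wrappers.TimeLimit(gym.make("Pendulum-v1"), max_episode_steps=200)

# A callable that does not return an environment now fails fast with a descriptive error:
bad_spec = lambda: 42  # make_env(bad_spec) would raise ValueError mentioning <class 'int'>
```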
@@ -70,6 +70,20 @@ class Algo(ABC):
     def evaluate(self, epoch):
         raise NotImplementedError("evaluate method must be implemented in subclass.")

-def dump_video(module):
-    if isinstance(module, VideoRecorder):
-        module.dump()
+    def predict(
+        self,
+        observation,
+        state=None,
+        deterministic=False
+    ):
+        with torch.no_grad():
+            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
+            td = TensorDict({"observation": obs_tensor}, batch_size=[1])
+
+            action_td = self.prob_actor(td)
+            action = action_td["action"]
+
+            # We're not using recurrent policies, so we'll always return None for the state
+            next_state = None
+
+            return action.squeeze(0).cpu().numpy(), next_state
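A minimal usage sketch of the new `predict()` API, assuming a trained algo object whose `prob_actor` maps `{"observation": ...}` to `{"action": ...}` as wired up in the PPO/TRPL constructors below; the environment ID is an arbitrary example:

```python
import gymnasium as gym

def rollout(algo, env_id="Pendulum-v1", episodes=1):
    env = gym.make(env_id)
    for _ in range(episodes):
        obs, _ = env.reset()
        done = False
        state = None  # predict() always returns None for state (no recurrent policies)
        while not done:
            action, state = algo.predict(obs, state=state)
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
    env.close()
```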
@@ -55,7 +55,7 @@ class OnPolicy(Algo):
         # Create collector
         self.collector = SyncDataCollector(
             create_env_fn=lambda: self.make_env(eval=False),
-            policy=self.actor,
+            policy=self.prob_actor,
             frames_per_batch=self.n_steps,
             total_frames=self.total_timesteps,
             device=self.device,
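The collector now receives `self.prob_actor` because `SyncDataCollector` needs a policy that actually writes an `"action"` entry into the rollout tensordict; the raw actor only produces distribution parameters. A generic torchrl sketch of that pattern (not fancy_rl code):

```python
import torch
from torch import nn
from tensordict.nn import TensorDictModule, NormalParamExtractor
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor

env = GymEnv("Pendulum-v1")
obs_dim = env.observation_spec["observation"].shape[-1]
act_dim = env.action_spec.shape[-1]

# Parameter network: maps observation -> (loc, scale) only, no sampling.
param_net = TensorDictModule(
    nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh(),
                  nn.Linear(64, 2 * act_dim), NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)

# The ProbabilisticActor samples an "action" (and its log-prob) from those
# parameters, which is what the collector needs in order to step the env.
policy = ProbabilisticActor(
    module=param_net,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=torch.distributions.Normal,
    return_log_prob=True,
)

collector = SyncDataCollector(
    create_env_fn=lambda: GymEnv("Pendulum-v1"),
    policy=policy,
    frames_per_batch=64,
    total_frames=128,
)
for batch in collector:
    print(batch["action"].shape)  # rollout batches carry sampled actions
collector.shutdown()
```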
@@ -4,7 +4,7 @@ from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
-from fancy_rl.projections import get_projection  # Updated import
+from fancy_rl.utils import is_discrete_space

 class PPO(OnPolicy):
     def __init__(
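PPO now imports `is_discrete_space` from `fancy_rl.utils`. Its implementation is not part of this diff; a plausible sketch of what such a helper checks, using gymnasium spaces (an assumption, for illustration only):

```python
import gymnasium as gym

def is_discrete_space_sketch(space):
    # Hypothetical stand-in for fancy_rl.utils.is_discrete_space
    return isinstance(space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete, gym.spaces.MultiBinary))

assert is_discrete_space_sketch(gym.spaces.Discrete(4))
assert not is_discrete_space_sketch(gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)))
```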
@@ -31,7 +31,10 @@ class PPO(OnPolicy):
         device=None,
         env_spec_eval=None,
         eval_episodes=10,
+        full_covariance=False,
     ):
+        self.clip_range = clip_range
+
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device

@@ -41,15 +44,29 @@ class PPO(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space

+        self.discrete = is_discrete_space(act_space)
+
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
-        self.actor = ProbabilisticActor(
-            module=actor_net,
-            in_keys=["loc", "scale"],
-            out_keys=["action"],
-            distribution_class=torch.distributions.Normal,
-            return_log_prob=True
-        )
+        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+
+        if self.discrete:
+            distribution_class = torch.distributions.Categorical
+            distribution_kwargs = {"logits": "action_logits"}
+        else:
+            if full_covariance:
+                distribution_class = torch.distributions.MultivariateNormal
+                in_keys = ["loc", "scale_tril"]
+            else:
+                distribution_class = torch.distributions.Normal
+                in_keys = ["loc", "scale"]
+
+        self.prob_actor = ProbabilisticActor(
+            module=self.actor,
+            distribution_class=distribution_class,
+            return_log_prob=True,
+            in_keys=in_keys,
+            out_keys=["action"]
+        )

         optimizers = {
             "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
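The `full_covariance` flag switches the Gaussian head from a diagonal `Normal(loc, scale)` to a `MultivariateNormal(loc, scale_tril=...)` parameterized by a lower-triangular Cholesky factor, which is why the actor's `in_keys` change to `["loc", "scale_tril"]`. A small standalone torch sketch of the two parameterizations:

```python
import torch

loc = torch.zeros(3)

# Diagonal case: per-dimension standard deviations ("scale").
scale = torch.ones(3)
diag = torch.distributions.Normal(loc, scale)
a = diag.sample()
print(diag.log_prob(a).sum(-1))  # independent dims: sum the per-dim log-probs

# Full-covariance case: lower-triangular Cholesky factor ("scale_tril").
scale_tril = torch.eye(3) + 0.1 * torch.randn(3, 3).tril(-1)  # unit diagonal, strictly lower noise
full = torch.distributions.MultivariateNormal(loc, scale_tril=scale_tril)
a = full.sample()
print(full.log_prob(a))  # joint log-prob, correlations allowed across dims
```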
@@ -86,7 +103,7 @@ class PPO(OnPolicy):
         self.loss_module = ClipPPOLoss(
             actor_network=self.actor,
             critic_network=self.critic,
-            clip_epsilon=clip_range,
+            clip_epsilon=self.clip_range,
             loss_critic_type='l2',
             entropy_coef=self.entropy_coef,
             critic_coef=self.critic_coef,
@@ -10,7 +10,36 @@ from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
 from fancy_rl.projections import get_projection, BaseProjection
 from fancy_rl.objectives import TRPLLoss
+from fancy_rl.utils import is_discrete_space
 from copy import deepcopy
+from tensordict.nn import TensorDictModule
+from tensordict import TensorDict
+
+class ProjectedActor(TensorDictModule):
+    def __init__(self, raw_actor, old_actor, projection):
+        combined_module = self.CombinedModule(raw_actor, old_actor, projection)
+        super().__init__(
+            module=combined_module,
+            in_keys=raw_actor.in_keys,
+            out_keys=raw_actor.out_keys
+        )
+        self.raw_actor = raw_actor
+        self.old_actor = old_actor
+        self.projection = projection
+
+    class CombinedModule(nn.Module):
+        def __init__(self, raw_actor, old_actor, projection):
+            super().__init__()
+            self.raw_actor = raw_actor
+            self.old_actor = old_actor
+            self.projection = projection
+
+        def forward(self, tensordict):
+            raw_params = self.raw_actor(tensordict)
+            old_params = self.old_actor(tensordict)
+            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
+            projected_params = self.projection(combined_params)
+            return projected_params

 class TRPL(OnPolicy):
     def __init__(
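`ProjectedActor.CombinedModule` feeds the projection a single TensorDict holding both the new parameters and the old ones under `old_`-prefixed keys. A minimal sketch of that merge (diagonal case, shapes chosen arbitrarily):

```python
import torch
from tensordict import TensorDict

raw_params = TensorDict({"loc": torch.zeros(4, 2), "scale": torch.ones(4, 2)}, batch_size=[4])
old_params = TensorDict({"loc": torch.zeros(4, 2), "scale": 2 * torch.ones(4, 2)}, batch_size=[4])

# Same dict-merge trick as CombinedModule.forward: prefix the old policy's keys.
combined = TensorDict(
    {**raw_params, **{f"old_{key}": value for key, value in old_params.items()}},
    batch_size=raw_params.batch_size,
)
print(sorted(combined.keys()))  # ['loc', 'old_loc', 'old_scale', 'scale']
```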
@@ -40,6 +69,7 @@ class TRPL(OnPolicy):
         device=None,
         env_spec_eval=None,
         eval_episodes=10,
+        full_covariance=False,
     ):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
@@ -50,8 +80,11 @@ class TRPL(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space

+        assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
+
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

         # Handle projection_class
         if isinstance(projection_class, str):
@@ -60,20 +93,27 @@ class TRPL(OnPolicy):
             raise ValueError("projection_class must be a string or a subclass of BaseProjection")

         self.projection = projection_class(
-            in_keys=["loc", "scale"],
-            out_keys=["loc", "scale"],
-            trust_region_bound_mean=trust_region_bound_mean,
-            trust_region_bound_cov=trust_region_bound_cov
+            in_keys=["loc", "scale_tril", "old_loc", "old_scale_tril"] if full_covariance else ["loc", "scale", "old_loc", "old_scale"],
+            out_keys=["loc", "scale_tril"] if full_covariance else ["loc", "scale"],
+            mean_bound=trust_region_bound_mean,
+            cov_bound=trust_region_bound_cov
         )

-        self.actor = ProbabilisticActor(
-            module=actor_net,
-            in_keys=["observation"],
-            out_keys=["loc", "scale"],
-            distribution_class=torch.distributions.Normal,
-            return_log_prob=True
+        self.actor = ProjectedActor(self.raw_actor, self.old_actor, self.projection)
+
+        if full_covariance:
+            distribution_class = torch.distributions.MultivariateNormal
+            distribution_kwargs = {"loc": "loc", "scale_tril": "scale_tril"}
+        else:
+            distribution_class = torch.distributions.Normal
+            distribution_kwargs = {"loc": "loc", "scale": "scale"}
+
+        self.prob_actor = ProbabilisticActor(
+            module=self.actor,
+            distribution_class=distribution_class,
+            return_log_prob=True,
+            in_keys=distribution_kwargs,
         )
-        self.old_actor = deepcopy(self.actor)

         self.trust_region_coef = trust_region_coef
         self.loss_module = TRPLLoss(
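The projection is constructed with `mean_bound`/`cov_bound` and consumes the `old_*` keys supplied by `ProjectedActor`. The actual trust-region projection lives in `fancy_rl.projections`; purely as an illustration of the interface (not the real projection math), a clamp-based stand-in for the diagonal case could look like:

```python
import torch
from torch import nn

class ClampProjection(nn.Module):
    """Illustrative stand-in only -- NOT the fancy_rl projection.
    Limits how far the new (loc, scale) may move from (old_loc, old_scale)."""

    def __init__(self, in_keys, out_keys, mean_bound=0.1, cov_bound=0.1):
        super().__init__()
        self.in_keys, self.out_keys = in_keys, out_keys
        self.mean_bound, self.cov_bound = mean_bound, cov_bound

    def forward(self, td):
        # Bound the mean shift elementwise.
        delta = (td["loc"] - td["old_loc"]).clamp(-self.mean_bound, self.mean_bound)
        td["loc"] = td["old_loc"] + delta
        # Bound the relative change of the (diagonal) scale.
        ratio = (td["scale"] / td["old_scale"]).clamp(1 - self.cov_bound, 1 + self.cov_bound)
        td["scale"] = td["old_scale"] * ratio
        return td
```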
@@ -88,7 +128,7 @@ class TRPL(OnPolicy):
         )

         optimizers = {
-            "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
+            "actor": torch.optim.Adam(self.raw_actor.parameters(), lr=learning_rate),
             "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
         }

@@ -119,23 +159,7 @@ class TRPL(OnPolicy):
         )

     def update_old_policy(self):
-        self.old_actor.load_state_dict(self.actor.state_dict())
+        self.old_actor.load_state_dict(self.raw_actor.state_dict())

-    def project_policy(self, obs):
-        with torch.no_grad():
-            old_dist = self.old_actor(obs)
-            new_dist = self.actor(obs)
-            projected_params = self.projection.project(new_dist, old_dist)
-        return projected_params
-
-    def pre_update(self, tensordict):
-        obs = tensordict["observation"]
-        projected_dist = self.project_policy(obs)
-
-        # Update tensordict with projected distribution parameters
-        tensordict["projected_loc"] = projected_dist[0]
-        tensordict["projected_scale"] = projected_dist[1]
-        return tensordict
-
     def post_update(self):
         self.update_old_policy()
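`update_old_policy` now copies from `raw_actor`, since `self.actor` is the `ProjectedActor` wrapper rather than a plain network. A minimal sketch of this weight-sync pattern with `load_state_dict`, using generic `nn.Linear` stand-ins:

```python
import torch
from torch import nn

# After each optimization step, the frozen "old" network is overwritten with the
# weights of the freshly updated "raw" network.
raw_actor = nn.Linear(4, 2)
old_actor = nn.Linear(4, 2)

old_actor.load_state_dict(raw_actor.state_dict())

with torch.no_grad():
    x = torch.randn(1, 4)
    assert torch.allclose(raw_actor(x), old_actor(x))  # old policy now mirrors the new one
```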