Rework algo impls
This commit is contained in:
parent
651ef1522f
commit
0c6e58634f
@ -3,6 +3,7 @@ import gymnasium as gym
|
||||
from torchrl.envs.libs.gym import GymWrapper
|
||||
from torchrl.record import VideoRecorder
|
||||
from abc import ABC
|
||||
from tensordict import TensorDict
|
||||
|
||||
from fancy_rl.loggers import TerminalLogger
|
||||
|
||||
@ -53,12 +54,11 @@ class Algo(ABC):
|
||||
env = GymWrapper(env).to(self.device)
|
||||
elif callable(env_spec):
|
||||
env = env_spec()
|
||||
if isinstance(env, gym.Env):
|
||||
env = GymWrapper(env).to(self.device)
|
||||
elif isinstance(env, gym.Env):
|
||||
if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
|
||||
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
|
||||
env = GymWrapper(env).to(self.device)
|
||||
else:
|
||||
raise ValueError("env_spec must be a string or a callable that returns an environment.")
|
||||
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
|
||||
return env
|
||||
|
||||
def train_step(self, batch):
|
||||
@ -70,6 +70,20 @@ class Algo(ABC):
|
||||
def evaluate(self, epoch):
|
||||
raise NotImplementedError("evaluate method must be implemented in subclass.")
|
||||
|
||||
def dump_video(module):
|
||||
if isinstance(module, VideoRecorder):
|
||||
module.dump()
|
||||
def predict(
|
||||
self,
|
||||
observation,
|
||||
state=None,
|
||||
deterministic=False
|
||||
):
|
||||
with torch.no_grad():
|
||||
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
|
||||
td = TensorDict({"observation": obs_tensor}, batch_size=[1])
|
||||
|
||||
action_td = self.prob_actor(td)
|
||||
action = action_td["action"]
|
||||
|
||||
# We're not using recurrent policies, so we'll always return None for the state
|
||||
next_state = None
|
||||
|
||||
return action.squeeze(0).cpu().numpy(), next_state
|
@ -55,7 +55,7 @@ class OnPolicy(Algo):
|
||||
# Create collector
|
||||
self.collector = SyncDataCollector(
|
||||
create_env_fn=lambda: self.make_env(eval=False),
|
||||
policy=self.actor,
|
||||
policy=self.prob_actor,
|
||||
frames_per_batch=self.n_steps,
|
||||
total_frames=self.total_timesteps,
|
||||
device=self.device,
|
||||
|
@ -4,7 +4,7 @@ from torchrl.objectives import ClipPPOLoss
|
||||
from torchrl.objectives.value.advantages import GAE
|
||||
from fancy_rl.algos.on_policy import OnPolicy
|
||||
from fancy_rl.policy import Actor, Critic
|
||||
from fancy_rl.projections import get_projection # Updated import
|
||||
from fancy_rl.utils import is_discrete_space
|
||||
|
||||
class PPO(OnPolicy):
|
||||
def __init__(
|
||||
@ -31,7 +31,10 @@ class PPO(OnPolicy):
|
||||
device=None,
|
||||
env_spec_eval=None,
|
||||
eval_episodes=10,
|
||||
full_covariance=False,
|
||||
):
|
||||
self.clip_range = clip_range
|
||||
|
||||
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.device = device
|
||||
|
||||
@ -41,15 +44,29 @@ class PPO(OnPolicy):
|
||||
obs_space = env.observation_space
|
||||
act_space = env.action_space
|
||||
|
||||
self.discrete = is_discrete_space(act_space)
|
||||
|
||||
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
|
||||
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
|
||||
self.actor = ProbabilisticActor(
|
||||
module=actor_net,
|
||||
in_keys=["loc", "scale"],
|
||||
out_keys=["action"],
|
||||
distribution_class=torch.distributions.Normal,
|
||||
return_log_prob=True
|
||||
)
|
||||
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
|
||||
|
||||
if self.discrete:
|
||||
distribution_class = torch.distributions.Categorical
|
||||
distribution_kwargs = {"logits": "action_logits"}
|
||||
else:
|
||||
if full_covariance:
|
||||
distribution_class = torch.distributions.MultivariateNormal
|
||||
in_keys = ["loc", "scale_tril"]
|
||||
else:
|
||||
distribution_class = torch.distributions.Normal
|
||||
in_keys = ["loc", "scale"]
|
||||
|
||||
self.prob_actor = ProbabilisticActor(
|
||||
module=self.actor,
|
||||
distribution_class=distribution_class,
|
||||
return_log_prob=True,
|
||||
in_keys=in_keys,
|
||||
out_keys=["action"]
|
||||
)
|
||||
|
||||
optimizers = {
|
||||
"actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
|
||||
@ -86,9 +103,9 @@ class PPO(OnPolicy):
|
||||
self.loss_module = ClipPPOLoss(
|
||||
actor_network=self.actor,
|
||||
critic_network=self.critic,
|
||||
clip_epsilon=clip_range,
|
||||
clip_epsilon=self.clip_range,
|
||||
loss_critic_type='l2',
|
||||
entropy_coef=self.entropy_coef,
|
||||
critic_coef=self.critic_coef,
|
||||
normalize_advantage=self.normalize_advantage,
|
||||
)
|
||||
)
|
@ -10,7 +10,36 @@ from fancy_rl.algos.on_policy import OnPolicy
|
||||
from fancy_rl.policy import Actor, Critic
|
||||
from fancy_rl.projections import get_projection, BaseProjection
|
||||
from fancy_rl.objectives import TRPLLoss
|
||||
from fancy_rl.utils import is_discrete_space
|
||||
from copy import deepcopy
|
||||
from tensordict.nn import TensorDictModule
|
||||
from tensordict import TensorDict
|
||||
|
||||
class ProjectedActor(TensorDictModule):
|
||||
def __init__(self, raw_actor, old_actor, projection):
|
||||
combined_module = self.CombinedModule(raw_actor, old_actor, projection)
|
||||
super().__init__(
|
||||
module=combined_module,
|
||||
in_keys=raw_actor.in_keys,
|
||||
out_keys=raw_actor.out_keys
|
||||
)
|
||||
self.raw_actor = raw_actor
|
||||
self.old_actor = old_actor
|
||||
self.projection = projection
|
||||
|
||||
class CombinedModule(nn.Module):
|
||||
def __init__(self, raw_actor, old_actor, projection):
|
||||
super().__init__()
|
||||
self.raw_actor = raw_actor
|
||||
self.old_actor = old_actor
|
||||
self.projection = projection
|
||||
|
||||
def forward(self, tensordict):
|
||||
raw_params = self.raw_actor(tensordict)
|
||||
old_params = self.old_actor(tensordict)
|
||||
combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
|
||||
projected_params = self.projection(combined_params)
|
||||
return projected_params
|
||||
|
||||
class TRPL(OnPolicy):
|
||||
def __init__(
|
||||
@ -40,6 +69,7 @@ class TRPL(OnPolicy):
|
||||
device=None,
|
||||
env_spec_eval=None,
|
||||
eval_episodes=10,
|
||||
full_covariance=False,
|
||||
):
|
||||
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.device = device
|
||||
@ -50,8 +80,11 @@ class TRPL(OnPolicy):
|
||||
obs_space = env.observation_space
|
||||
act_space = env.action_space
|
||||
|
||||
assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
|
||||
|
||||
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
|
||||
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
|
||||
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
|
||||
self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
|
||||
|
||||
# Handle projection_class
|
||||
if isinstance(projection_class, str):
|
||||
@ -60,20 +93,27 @@ class TRPL(OnPolicy):
|
||||
raise ValueError("projection_class must be a string or a subclass of BaseProjection")
|
||||
|
||||
self.projection = projection_class(
|
||||
in_keys=["loc", "scale"],
|
||||
out_keys=["loc", "scale"],
|
||||
trust_region_bound_mean=trust_region_bound_mean,
|
||||
trust_region_bound_cov=trust_region_bound_cov
|
||||
in_keys=["loc", "scale_tril", "old_loc", "old_scale_tril"] if full_covariance else ["loc", "scale", "old_loc", "old_scale"],
|
||||
out_keys=["loc", "scale_tril"] if full_covariance else ["loc", "scale"],
|
||||
mean_bound=trust_region_bound_mean,
|
||||
cov_bound=trust_region_bound_cov
|
||||
)
|
||||
|
||||
self.actor = ProbabilisticActor(
|
||||
module=actor_net,
|
||||
in_keys=["observation"],
|
||||
out_keys=["loc", "scale"],
|
||||
distribution_class=torch.distributions.Normal,
|
||||
return_log_prob=True
|
||||
self.actor = ProjectedActor(self.raw_actor, self.old_actor, self.projection)
|
||||
|
||||
if full_covariance:
|
||||
distribution_class = torch.distributions.MultivariateNormal
|
||||
distribution_kwargs = {"loc": "loc", "scale_tril": "scale_tril"}
|
||||
else:
|
||||
distribution_class = torch.distributions.Normal
|
||||
distribution_kwargs = {"loc": "loc", "scale": "scale"}
|
||||
|
||||
self.prob_actor = ProbabilisticActor(
|
||||
module=self.actor,
|
||||
distribution_class=distribution_class,
|
||||
return_log_prob=True,
|
||||
in_keys=distribution_kwargs,
|
||||
)
|
||||
self.old_actor = deepcopy(self.actor)
|
||||
|
||||
self.trust_region_coef = trust_region_coef
|
||||
self.loss_module = TRPLLoss(
|
||||
@ -88,7 +128,7 @@ class TRPL(OnPolicy):
|
||||
)
|
||||
|
||||
optimizers = {
|
||||
"actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
|
||||
"actor": torch.optim.Adam(self.raw_actor.parameters(), lr=learning_rate),
|
||||
"critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
|
||||
}
|
||||
|
||||
@ -119,23 +159,7 @@ class TRPL(OnPolicy):
|
||||
)
|
||||
|
||||
def update_old_policy(self):
|
||||
self.old_actor.load_state_dict(self.actor.state_dict())
|
||||
|
||||
def project_policy(self, obs):
|
||||
with torch.no_grad():
|
||||
old_dist = self.old_actor(obs)
|
||||
new_dist = self.actor(obs)
|
||||
projected_params = self.projection.project(new_dist, old_dist)
|
||||
return projected_params
|
||||
|
||||
def pre_update(self, tensordict):
|
||||
obs = tensordict["observation"]
|
||||
projected_dist = self.project_policy(obs)
|
||||
|
||||
# Update tensordict with projected distribution parameters
|
||||
tensordict["projected_loc"] = projected_dist[0]
|
||||
tensordict["projected_scale"] = projected_dist[1]
|
||||
return tensordict
|
||||
self.old_actor.load_state_dict(self.raw_actor.state_dict())
|
||||
|
||||
def post_update(self):
|
||||
self.update_old_policy()
|
||||
self.update_old_policy()
|
Loading…
Reference in New Issue
Block a user