Compare commits

...

2 Commits

SHA1        Message                 Date
7861821d0d  Worked on TRPL module   2024-06-02 14:14:12 +02:00
65c6a950aa  Use loggers correclty   2024-06-02 14:13:36 +02:00
3 changed files with 37 additions and 51 deletions

View File

@@ -9,6 +9,8 @@ from torchrl.record import VideoRecorder
 from tensordict import LazyStackedTensorDict, TensorDict
 from abc import ABC
+from fancy_rl.loggers import TerminalLogger
 class OnPolicy(ABC):
     def __init__(
         self,
@@ -32,7 +34,7 @@ class OnPolicy(ABC):
     ):
         self.env_spec = env_spec
         self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
-        self.loggers = loggers
+        self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
         self.optimizers = optimizers
         self.learning_rate = learning_rate
         self.n_steps = n_steps
@@ -110,7 +112,7 @@ class OnPolicy(ABC):
                 batch = batch.to(self.device)
                 loss = self.train_step(batch)
                 for logger in self.loggers:
-                    logger.log_scalar({"loss": loss.item()}, step=collected_frames)
+                    logger.log_scalar("loss", loss.item(), step=collected_frames)
             if (t + 1) % self.eval_interval == 0:
                 self.evaluate(t)
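
In the training loop above, each scalar is now logged by name and value instead of as a dict. A minimal stand-in that satisfies the log_scalar(name, value, step=...) call site (illustrative only, not fancy_rl's actual TerminalLogger, whose constructor arguments are not shown here):

# Illustrative stand-in -- not fancy_rl's TerminalLogger implementation.
class StdoutLogger:
    def log_scalar(self, name, value, step=None):
        # Prints one scalar per line, e.g. "step 2048 | loss: 0.0132"
        prefix = f"step {step} | " if step is not None else ""
        print(f"{prefix}{name}: {value:.4f}")

# Usage mirroring the new call site:
# for logger in self.loggers:
#     logger.log_scalar("loss", loss.item(), step=collected_frames)

Passing the name separately lets a backend key each scalar directly, and the new [TerminalLogger(None, None)] fallback guarantees at least one logger is always present.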

View File

@@ -1,5 +1,5 @@
 import torch
-from torchrl.modules import ActorValueOperator, ProbabilisticActor
+from torchrl.modules import ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
@@ -9,12 +9,11 @@ class PPO(OnPolicy):
     def __init__(
         self,
         env_spec,
-        loggers=[],
+        loggers=None,
         actor_hidden_sizes=[64, 64],
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
         critic_activation_fn="Tanh",
-        shared_stem_sizes=[64],
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
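
Switching the default from loggers=[] to loggers=None sidesteps Python's shared-mutable-default pitfall, and the base class (see the change above) now substitutes [TerminalLogger(None, None)] when nothing is passed. A small generic demonstration of why the old default was risky (plain Python, not fancy_rl code):

# Generic demonstration of the mutable-default pitfall -- not fancy_rl code.
class WithListDefault:
    def __init__(self, loggers=[]):      # one list object shared by every call
        self.loggers = loggers

class WithNoneDefault:
    def __init__(self, loggers=None):    # sentinel resolved per instance
        self.loggers = loggers if loggers is not None else []

a, b = WithListDefault(), WithListDefault()
a.loggers.append("terminal")
print(b.loggers)   # ['terminal'] -- b silently sees a's logger

c, d = WithNoneDefault(), WithNoneDefault()
c.loggers.append("terminal")
print(d.loggers)   # [] -- instances stay independent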

View File

@@ -1,11 +1,11 @@
 import torch
-from torchrl.modules import ActorValueOperator, ProbabilisticActor
+from torchrl.modules import ProbabilisticActor
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
-from fancy_rl.policy import Actor, Critic, SharedModule
+from fancy_rl.policy import Actor, Critic
 from fancy_rl.objectives import TRPLLoss
-class TRPL(OnPolicy):
+class PPO(OnPolicy):
     def __init__(
         self,
         env_spec,
@@ -14,7 +14,6 @@ class TRPL(OnPolicy):
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
         critic_activation_fn="Tanh",
-        shared_stem_sizes=[64],
         proj_layer_type=None,
         learning_rate=3e-4,
         n_steps=2048,
@@ -28,14 +27,16 @@ class TRPL(OnPolicy):
         entropy_coef=0.01,
         critic_coef=0.5,
         trust_region_coef=10.0,
-        normalize_advantage=True,
+        normalize_advantage=False,
         device=None,
         env_spec_eval=None,
         eval_episodes=10,
     ):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.trust_region_layer = None # from proj_layer_type
+        self.device = device
+        self.trust_region_layer = None # TODO: from proj_layer_type
+        self.trust_region_coef = trust_region_coef
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
@@ -43,55 +44,23 @@ class TRPL(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
-        # Define the shared, actor, and critic modules
-        self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
-        self.raw_actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
-        self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)
-        # Perfrom projection
-        self.actor = self.raw_actor # TODO: Project
-        # Combine into an ActorValueOperator
-        self.ac_module = ActorValueOperator(
-            self.shared_module,
-            self.actor,
-            self.critic
-        )
-        # Define the policy as a ProbabilisticActor
-        policy = ProbabilisticActor(
-            module=self.ac_module.get_policy_operator(),
+        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        raw_actor = ProbabilisticActor(
+            module=actor_net,
             in_keys=["loc", "scale"],
             out_keys=["action"],
             distribution_class=torch.distributions.Normal,
             return_log_prob=True
         )
+        self.actor = raw_actor # TODO: Proj here
         optimizers = {
             "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
             "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
         }
-        self.adv_module = GAE(
-            gamma=self.gamma,
-            lmbda=self.gae_lambda,
-            value_network=self.critic,
-            average_gae=False,
-        )
-        self.loss_module = TRPLLoss(
-            actor_network=self.actor,
-            critic_network=self.critic,
-            trust_region_layer=self.trust_region_layer,
-            loss_critic_type='MSELoss',
-            entropy_coef=self.entropy_coef,
-            critic_coef=self.critic_coef,
-            trust_region_coef=self.trust_region_coef,
-            normalize_advantage=self.normalize_advantage,
-        )
         super().__init__(
-            policy=policy,
            env_spec=env_spec,
            loggers=loggers,
            optimizers=optimizers,
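
The hunk above drops the SharedModule/ActorValueOperator stack: the critic and actor nets are now built directly from the observation space, and the actor net is wrapped in a ProbabilisticActor that reads "loc"/"scale" and writes "action". A sketch of that torchrl wiring, with a plain MLP and placeholder sizes standing in for fancy_rl's Actor (the "observation" key and dimensions are assumptions):

# Sketch of the ProbabilisticActor wiring used above; the backbone and sizes are
# placeholders, not fancy_rl's actual Actor module.
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules import NormalParamExtractor, ProbabilisticActor

obs_dim, act_dim = 4, 2  # placeholders; the real code reads these from env_spec

backbone = nn.Sequential(
    nn.Linear(obs_dim, 64), nn.Tanh(),
    nn.Linear(64, 2 * act_dim),
    NormalParamExtractor(),  # splits the last layer's output into loc and scale
)
actor_net = TensorDictModule(backbone, in_keys=["observation"], out_keys=["loc", "scale"])

policy = ProbabilisticActor(
    module=actor_net,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=torch.distributions.Normal,
    return_log_prob=True,
)

td = policy(TensorDict({"observation": torch.randn(1, obs_dim)}, batch_size=[1]))
print(td["action"].shape)  # torch.Size([1, 2])
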
@@ -100,15 +69,31 @@ class TRPL(OnPolicy):
             batch_size=batch_size,
             n_epochs=n_epochs,
             gamma=gamma,
-            gae_lambda=gae_lambda,
             total_timesteps=total_timesteps,
             eval_interval=eval_interval,
             eval_deterministic=eval_deterministic,
             entropy_coef=entropy_coef,
             critic_coef=critic_coef,
             normalize_advantage=normalize_advantage,
-            clip_range=clip_range,
             device=device,
             env_spec_eval=env_spec_eval,
             eval_episodes=eval_episodes,
         )
+        self.adv_module = GAE(
+            gamma=self.gamma,
+            lmbda=gae_lambda,
+            value_network=self.critic,
+            average_gae=False,
+        )
+        self.loss_module = TRPLLoss(
+            actor_network=self.actor,
+            critic_network=self.critic,
+            trust_region_layer=self.trust_region_layer,
+            loss_critic_type='l2',
+            entropy_coef=self.entropy_coef,
+            critic_coef=self.critic_coef,
+            trust_region_coef=self.trust_region_coef,
+            normalize_advantage=self.normalize_advantage,
+        )
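
With this reordering, GAE and TRPLLoss are constructed only after super().__init__ has run, which presumably guarantees that the attributes read here (self.gamma, self.entropy_coef, self.critic_coef, self.normalize_advantage) have already been set by the base class; the critic loss type also switches from 'MSELoss' to 'l2'. A sketch of the GAE construction under those assumptions, with a stand-in value network instead of fancy_rl's Critic:

# Sketch only: a stand-in value network wired into torchrl's GAE the same way
# the diff does; fancy_rl's Critic is replaced by a plain MLP here.
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.objectives.value.advantages import GAE

value_net = TensorDictModule(
    nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 1)),
    in_keys=["observation"],
    out_keys=["state_value"],  # GAE reads the critic's estimate from this key
)

adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=value_net, average_gae=False)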