Compare commits


No commits in common. "7861821d0d463b27b05e9f4be3812aa27baa235f" and "4091df45f5c356b1f9c8127f93c1dc3e96f22247" have entirely different histories.

3 changed files with 54 additions and 40 deletions

View File

@@ -9,8 +9,6 @@ from torchrl.record import VideoRecorder
 from tensordict import LazyStackedTensorDict, TensorDict
 from abc import ABC
-from fancy_rl.loggers import TerminalLogger
 
 class OnPolicy(ABC):
     def __init__(
         self,
@@ -34,7 +32,7 @@ class OnPolicy(ABC):
     ):
         self.env_spec = env_spec
         self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
-        self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
+        self.loggers = loggers
         self.optimizers = optimizers
         self.learning_rate = learning_rate
         self.n_steps = n_steps
@@ -112,7 +110,7 @@ class OnPolicy(ABC):
                 batch = batch.to(self.device)
                 loss = self.train_step(batch)
                 for logger in self.loggers:
-                    logger.log_scalar("loss", loss.item(), step=collected_frames)
+                    logger.log_scalar({"loss": loss.item()}, step=collected_frames)
 
             if (t + 1) % self.eval_interval == 0:
                 self.evaluate(t)
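
Note: the logging call now passes a dict of scalars instead of a name/value pair, and the TerminalLogger fallback is gone, so callers must supply their own loggers. Below is a minimal sketch of a compatible logger, assuming only the log_scalar(dict, step=...) call shape visible in the hunk above; the class name, constructor, and print format are illustrative, not part of fancy_rl.

class PrintLogger:
    """Illustrative stand-in for a fancy_rl logger; only the call shape is taken from the diff."""

    def log_scalar(self, scalars, step=None):
        # 'scalars' is a dict mapping metric names to values, e.g. {"loss": 0.42}
        for name, value in scalars.items():
            print(f"step={step} {name}={value:.4f}")

# With the TerminalLogger default removed, a logger list has to be passed explicitly,
# e.g. PPO(env_spec="Pendulum-v1", loggers=[PrintLogger()])  # env id is a placeholder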

View File

@@ -1,5 +1,5 @@
 import torch
-from torchrl.modules import ProbabilisticActor
+from torchrl.modules import ActorValueOperator, ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
@@ -9,11 +9,12 @@ class PPO(OnPolicy):
     def __init__(
         self,
         env_spec,
-        loggers=None,
+        loggers=[],
         actor_hidden_sizes=[64, 64],
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
         critic_activation_fn="Tanh",
+        shared_stem_sizes=[64],
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
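
A hedged sketch of calling the updated PPO constructor; the module path and environment id are assumptions, while the keyword names and default values are the ones visible in the hunk above.

from fancy_rl.algos.ppo import PPO  # assumed module path; the diff does not show file names

ppo = PPO(
    env_spec="Pendulum-v1",        # placeholder continuous-control env id
    loggers=[],                    # default changed from None to an empty list
    actor_hidden_sizes=[64, 64],
    critic_hidden_sizes=[64, 64],
    shared_stem_sizes=[64],        # new parameter: layer widths of the shared trunk
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
)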

View File

@@ -1,11 +1,11 @@
 import torch
-from torchrl.modules import ProbabilisticActor
+from torchrl.modules import ActorValueOperator, ProbabilisticActor
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
-from fancy_rl.policy import Actor, Critic
+from fancy_rl.policy import Actor, Critic, SharedModule
 from fancy_rl.objectives import TRPLLoss
 
-class PPO(OnPolicy):
+class TRPL(OnPolicy):
     def __init__(
         self,
         env_spec,
@@ -14,6 +14,7 @@ class PPO(OnPolicy):
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
         critic_activation_fn="Tanh",
+        shared_stem_sizes=[64],
         proj_layer_type=None,
         learning_rate=3e-4,
         n_steps=2048,
@@ -27,16 +28,14 @@ class PPO(OnPolicy):
         entropy_coef=0.01,
         critic_coef=0.5,
         trust_region_coef=10.0,
-        normalize_advantage=False,
+        normalize_advantage=True,
         device=None,
         env_spec_eval=None,
         eval_episodes=10,
     ):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.device = device
-        self.trust_region_layer = None  # TODO: from proj_layer_type
-        self.trust_region_coef = trust_region_coef
+        self.trust_region_layer = None  # from proj_layer_type
 
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
@@ -44,45 +43,38 @@ class PPO(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
 
-        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
-        raw_actor = ProbabilisticActor(
-            module=actor_net,
+        # Define the shared, actor, and critic modules
+        self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
+        self.raw_actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)
+
+        # Perform projection
+        self.actor = self.raw_actor  # TODO: Project
+
+        # Combine into an ActorValueOperator
+        self.ac_module = ActorValueOperator(
+            self.shared_module,
+            self.actor,
+            self.critic
+        )
+
+        # Define the policy as a ProbabilisticActor
+        policy = ProbabilisticActor(
+            module=self.ac_module.get_policy_operator(),
             in_keys=["loc", "scale"],
             out_keys=["action"],
             distribution_class=torch.distributions.Normal,
             return_log_prob=True
         )
 
-        self.actor = raw_actor  # TODO: Proj here
-
         optimizers = {
             "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
             "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
         }
 
-        super().__init__(
-            env_spec=env_spec,
-            loggers=loggers,
-            optimizers=optimizers,
-            learning_rate=learning_rate,
-            n_steps=n_steps,
-            batch_size=batch_size,
-            n_epochs=n_epochs,
-            gamma=gamma,
-            total_timesteps=total_timesteps,
-            eval_interval=eval_interval,
-            eval_deterministic=eval_deterministic,
-            entropy_coef=entropy_coef,
-            critic_coef=critic_coef,
-            normalize_advantage=normalize_advantage,
-            device=device,
-            env_spec_eval=env_spec_eval,
-            eval_episodes=eval_episodes,
-        )
-
         self.adv_module = GAE(
             gamma=self.gamma,
-            lmbda=gae_lambda,
+            lmbda=self.gae_lambda,
             value_network=self.critic,
             average_gae=False,
         )
@@ -91,9 +83,32 @@ class PPO(OnPolicy):
             actor_network=self.actor,
             critic_network=self.critic,
             trust_region_layer=self.trust_region_layer,
-            loss_critic_type='l2',
+            loss_critic_type='MSELoss',
             entropy_coef=self.entropy_coef,
             critic_coef=self.critic_coef,
             trust_region_coef=self.trust_region_coef,
             normalize_advantage=self.normalize_advantage,
         )
+
+        super().__init__(
+            policy=policy,
+            env_spec=env_spec,
+            loggers=loggers,
+            optimizers=optimizers,
+            learning_rate=learning_rate,
+            n_steps=n_steps,
+            batch_size=batch_size,
+            n_epochs=n_epochs,
+            gamma=gamma,
+            gae_lambda=gae_lambda,
+            total_timesteps=total_timesteps,
+            eval_interval=eval_interval,
+            eval_deterministic=eval_deterministic,
+            entropy_coef=entropy_coef,
+            critic_coef=critic_coef,
+            normalize_advantage=normalize_advantage,
+            clip_range=clip_range,
+            device=device,
+            env_spec_eval=env_spec_eval,
+            eval_episodes=eval_episodes,
+        )
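
For context on the new wiring, here is a generic torchrl sketch of the shared-stem actor-critic layout that the TRPL constructor now assembles from SharedModule, Actor, and Critic. The plain nn.Sequential stems, layer sizes, and TensorDict key names below are assumptions standing in for fancy_rl's own modules; only the ActorValueOperator/ProbabilisticActor composition mirrors the diff.

import torch
from torch import nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import ActorValueOperator, ProbabilisticActor

obs_dim, act_dim, hidden = 3, 1, 64  # placeholder sizes

# Shared trunk: observation -> hidden features
common = TensorDictModule(
    nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh()),
    in_keys=["observation"], out_keys=["hidden"],
)
# Policy head: hidden features -> Normal(loc, scale) parameters
policy_head = TensorDictModule(
    nn.Sequential(nn.Linear(hidden, 2 * act_dim), NormalParamExtractor()),
    in_keys=["hidden"], out_keys=["loc", "scale"],
)
# Value head: hidden features -> state value
value_head = TensorDictModule(
    nn.Linear(hidden, 1), in_keys=["hidden"], out_keys=["state_value"],
)

# Combine trunk and heads, then expose the policy path as a ProbabilisticActor,
# matching the module=self.ac_module.get_policy_operator() call in the diff.
ac = ActorValueOperator(common, policy_head, value_head)
policy = ProbabilisticActor(
    module=ac.get_policy_operator(),
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=torch.distributions.Normal,
    return_log_prob=True,
)

In this layout, policy runs trunk plus policy head on a TensorDict with an "observation" entry, while ac.get_value_operator() gives the critic path that GAE and the loss module consume, so both heads share the same stem parameters.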