Use loggers correctly

This commit is contained in:
Dominik Moritz Roth 2024-06-02 14:13:36 +02:00
parent 4091df45f5
commit 65c6a950aa
2 changed files with 6 additions and 5 deletions

View File

@ -9,6 +9,8 @@ from torchrl.record import VideoRecorder
from tensordict import LazyStackedTensorDict, TensorDict from tensordict import LazyStackedTensorDict, TensorDict
from abc import ABC from abc import ABC
from fancy_rl.loggers import TerminalLogger
class OnPolicy(ABC): class OnPolicy(ABC):
def __init__( def __init__(
self, self,
@ -32,7 +34,7 @@ class OnPolicy(ABC):
): ):
self.env_spec = env_spec self.env_spec = env_spec
self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
self.loggers = loggers self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
self.optimizers = optimizers self.optimizers = optimizers
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.n_steps = n_steps self.n_steps = n_steps
@ -110,7 +112,7 @@ class OnPolicy(ABC):
batch = batch.to(self.device) batch = batch.to(self.device)
loss = self.train_step(batch) loss = self.train_step(batch)
for logger in self.loggers: for logger in self.loggers:
logger.log_scalar({"loss": loss.item()}, step=collected_frames) logger.log_scalar("loss", loss.item(), step=collected_frames)
if (t + 1) % self.eval_interval == 0: if (t + 1) % self.eval_interval == 0:
self.evaluate(t) self.evaluate(t)

View File

@ -1,5 +1,5 @@
import torch import torch
from torchrl.modules import ActorValueOperator, ProbabilisticActor from torchrl.modules import ProbabilisticActor
from torchrl.objectives import ClipPPOLoss from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE from torchrl.objectives.value.advantages import GAE
from fancy_rl.algos.on_policy import OnPolicy from fancy_rl.algos.on_policy import OnPolicy
@ -9,12 +9,11 @@ class PPO(OnPolicy):
def __init__( def __init__(
self, self,
env_spec, env_spec,
loggers=[], loggers=None,
actor_hidden_sizes=[64, 64], actor_hidden_sizes=[64, 64],
critic_hidden_sizes=[64, 64], critic_hidden_sizes=[64, 64],
actor_activation_fn="Tanh", actor_activation_fn="Tanh",
critic_activation_fn="Tanh", critic_activation_fn="Tanh",
shared_stem_sizes=[64],
learning_rate=3e-4, learning_rate=3e-4,
n_steps=2048, n_steps=2048,
batch_size=64, batch_size=64,