Use loggers correctly

This commit is contained in:
Dominik Moritz Roth 2024-06-02 14:13:36 +02:00
parent 4091df45f5
commit 65c6a950aa
2 changed files with 6 additions and 5 deletions

View File

@ -9,6 +9,8 @@ from torchrl.record import VideoRecorder
from tensordict import LazyStackedTensorDict, TensorDict
from abc import ABC
from fancy_rl.loggers import TerminalLogger
class OnPolicy(ABC):
def __init__(
self,
@ -32,7 +34,7 @@ class OnPolicy(ABC):
):
self.env_spec = env_spec
self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
self.loggers = loggers
self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
self.optimizers = optimizers
self.learning_rate = learning_rate
self.n_steps = n_steps
@ -110,7 +112,7 @@ class OnPolicy(ABC):
batch = batch.to(self.device)
loss = self.train_step(batch)
for logger in self.loggers:
logger.log_scalar({"loss": loss.item()}, step=collected_frames)
logger.log_scalar("loss", loss.item(), step=collected_frames)
if (t + 1) % self.eval_interval == 0:
self.evaluate(t)

View File

@ -1,5 +1,5 @@
import torch
from torchrl.modules import ActorValueOperator, ProbabilisticActor
from torchrl.modules import ProbabilisticActor
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from fancy_rl.algos.on_policy import OnPolicy
@ -9,12 +9,11 @@ class PPO(OnPolicy):
def __init__(
self,
env_spec,
loggers=[],
loggers=None,
actor_hidden_sizes=[64, 64],
critic_hidden_sizes=[64, 64],
actor_activation_fn="Tanh",
critic_activation_fn="Tanh",
shared_stem_sizes=[64],
learning_rate=3e-4,
n_steps=2048,
batch_size=64,