Use loggers correctly
parent 4091df45f5
commit 65c6a950aa
fancy_rl/algos/on_policy.py

@@ -9,6 +9,8 @@ from torchrl.record import VideoRecorder
 from tensordict import LazyStackedTensorDict, TensorDict
 from abc import ABC
+
+from fancy_rl.loggers import TerminalLogger
 
 class OnPolicy(ABC):
     def __init__(
         self,
@@ -32,7 +34,7 @@ class OnPolicy(ABC):
     ):
         self.env_spec = env_spec
         self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
-        self.loggers = loggers
+        self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
         self.optimizers = optimizers
         self.learning_rate = learning_rate
         self.n_steps = n_steps
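The fallback means a run configured with no loggers still reports to the terminal, while an explicitly passed empty list disables logging on purpose. A minimal sketch of the pattern, using a hypothetical stand-in for TerminalLogger (the real constructor's parameters are not shown in this diff beyond the two None arguments):

# Sketch only: TerminalLogger here is a hypothetical stand-in, not the
# real fancy_rl.loggers implementation.
class TerminalLogger:
    def __init__(self, exp_name, log_dir):
        # Both arguments may be None, as in the diff above.
        self.exp_name = exp_name
        self.log_dir = log_dir

    def log_scalar(self, name, value, step=None):
        print(f"step={step} {name}={value}")

def resolve_loggers(loggers=None):
    # `loggers != None` works, but `loggers is not None` is the idiomatic
    # None check; both keep an explicit [] meaning "log nothing".
    return loggers if loggers is not None else [TerminalLogger(None, None)]

resolve_loggers()    # -> [TerminalLogger(None, None)]: terminal fallback
resolve_loggers([])  # -> []: logging deliberately disabled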
@@ -110,7 +112,7 @@ class OnPolicy(ABC):
             batch = batch.to(self.device)
             loss = self.train_step(batch)
             for logger in self.loggers:
-                logger.log_scalar({"loss": loss.item()}, step=collected_frames)
+                logger.log_scalar("loss", loss.item(), step=collected_frames)
 
             if (t + 1) % self.eval_interval == 0:
                 self.evaluate(t)
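The old call packed the metric into a dict, which does not fit a log_scalar(name, value, step=...) signature; the new call passes the name and value separately. A short runnable sketch of the corrected convention, again with a hypothetical logger stand-in:

# StubLogger is a hypothetical stand-in for any fancy_rl logger backend.
class StubLogger:
    def log_scalar(self, name, value, step=None):
        print(f"[{step}] {name} = {value:.4f}")

loggers = [StubLogger()]
loss_value = 0.2531  # stands in for loss.item() from train_step

for logger in loggers:
    # Old (broken): logger.log_scalar({"loss": loss_value}, step=2048)
    logger.log_scalar("loss", loss_value, step=2048)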
fancy_rl/algos/ppo.py

@@ -1,5 +1,5 @@
 import torch
-from torchrl.modules import ActorValueOperator, ProbabilisticActor
+from torchrl.modules import ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
@@ -9,12 +9,11 @@ class PPO(OnPolicy):
     def __init__(
         self,
         env_spec,
-        loggers=[],
+        loggers=None,
         actor_hidden_sizes=[64, 64],
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
         critic_activation_fn="Tanh",
         shared_stem_sizes=[64],
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
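Changing the PPO default from loggers=[] to loggers=None also sidesteps Python's mutable-default pitfall: a default list is built once at function definition time and shared across all calls, so one instance mutating it would leak state into the next. A small self-contained illustration:

# The mutable-default pitfall that loggers=None avoids.
def with_list_default(items=[]):
    items.append("run")
    return items

with_list_default()  # ['run']
with_list_default()  # ['run', 'run']: the single default list is shared

def with_none_default(items=None):
    items = items if items is not None else []
    items.append("run")
    return items

with_none_default()  # ['run']
with_none_default()  # ['run']: a fresh list on every call

With loggers=None, the actual default is deferred to OnPolicy.__init__, which substitutes the TerminalLogger fallback shown above.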