Use loggers correctly
parent 4091df45f5
commit 65c6a950aa
@@ -9,6 +9,8 @@ from torchrl.record import VideoRecorder
 from tensordict import LazyStackedTensorDict, TensorDict
 from abc import ABC
 
+from fancy_rl.loggers import TerminalLogger
+
 class OnPolicy(ABC):
     def __init__(
         self,
@@ -32,7 +34,7 @@ class OnPolicy(ABC):
     ):
         self.env_spec = env_spec
         self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
-        self.loggers = loggers
+        self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
         self.optimizers = optimizers
         self.learning_rate = learning_rate
         self.n_steps = n_steps
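The fallback above only fires when loggers is None, which is why the PPO constructor later in this diff switches its default from a mutable [] to None (and "is not None" would be the more idiomatic comparison than "!= None"). A minimal sketch of the pattern with hypothetical names; ConsoleLogger and Trainer are illustrative, not the fancy_rl API:

# Hypothetical sketch of the None-default pattern used above; not fancy_rl code.
class ConsoleLogger:
    """Minimal stand-in for a terminal logger."""

    def log_scalar(self, name, value, step=None):
        print(f"[step {step}] {name}: {value}")


class Trainer:
    def __init__(self, loggers=None):
        # Substitute one default logger only when the caller passed nothing.
        # A mutable default such as loggers=[] would be shared across calls
        # and is also falsy, which makes the fallback check easy to get wrong.
        self.loggers = loggers if loggers is not None else [ConsoleLogger()]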
@@ -110,7 +112,7 @@ class OnPolicy(ABC):
                 batch = batch.to(self.device)
                 loss = self.train_step(batch)
                 for logger in self.loggers:
-                    logger.log_scalar({"loss": loss.item()}, step=collected_frames)
+                    logger.log_scalar("loss", loss.item(), step=collected_frames)
 
             if (t + 1) % self.eval_interval == 0:
                 self.evaluate(t)
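The log_scalar fix above passes the metric name and value as separate arguments rather than a dict, matching a log_scalar(name, value, step=...) style of interface. A hedged sketch of a logger that accepts the corrected call; the real TerminalLogger constructor arguments are not visible in this diff, so the placeholders below are assumptions:

# Illustrative only; not the actual fancy_rl.loggers.TerminalLogger implementation.
class TerminalLoggerSketch:
    def __init__(self, exp_name=None, log_dir=None):
        # Two optional arguments, mirroring the TerminalLogger(None, None) call
        # in the diff; their real meaning is an assumption here.
        self.exp_name = exp_name
        self.log_dir = log_dir

    def log_scalar(self, name, value, step=None):
        # Matches the corrected call site:
        # logger.log_scalar("loss", loss.item(), step=collected_frames)
        print(f"step={step} {name}={value:.4f}")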
@@ -1,5 +1,5 @@
 import torch
-from torchrl.modules import ActorValueOperator, ProbabilisticActor
+from torchrl.modules import ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
@@ -9,12 +9,11 @@ class PPO(OnPolicy):
     def __init__(
         self,
         env_spec,
-        loggers=[],
+        loggers=None,
         actor_hidden_sizes=[64, 64],
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
         critic_activation_fn="Tanh",
-        shared_stem_sizes=[64],
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
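With loggers=None as the new default, constructing PPO without an explicit logger falls back to a TerminalLogger inside OnPolicy.__init__. A hedged usage sketch; the import path, the env_spec string, and the train() call are assumptions about the surrounding fancy_rl API rather than something shown in this diff:

from fancy_rl.algos.ppo import PPO  # assumed import path for the PPO class edited above

# No loggers passed: OnPolicy now substitutes [TerminalLogger(None, None)]
# instead of keeping a shared, mutable empty list as the default.
ppo = PPO(env_spec="CartPole-v1")
ppo.train()

# An explicit list of loggers still takes precedence over the terminal fallback:
# ppo = PPO(env_spec="CartPole-v1", loggers=[my_custom_logger])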