Use loggers correclty
This commit is contained in:
		
							parent
							
								
									4091df45f5
								
							
						
					
					
						commit
						65c6a950aa
					
				@ -9,6 +9,8 @@ from torchrl.record import VideoRecorder
 | 
				
			|||||||
from tensordict import LazyStackedTensorDict, TensorDict
 | 
					from tensordict import LazyStackedTensorDict, TensorDict
 | 
				
			||||||
from abc import ABC
 | 
					from abc import ABC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from fancy_rl.loggers import TerminalLogger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class OnPolicy(ABC):
 | 
					class OnPolicy(ABC):
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
@ -32,7 +34,7 @@ class OnPolicy(ABC):
 | 
				
			|||||||
    ):
 | 
					    ):
 | 
				
			||||||
        self.env_spec = env_spec
 | 
					        self.env_spec = env_spec
 | 
				
			||||||
        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
 | 
					        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
 | 
				
			||||||
        self.loggers = loggers
 | 
					        self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
 | 
				
			||||||
        self.optimizers = optimizers
 | 
					        self.optimizers = optimizers
 | 
				
			||||||
        self.learning_rate = learning_rate
 | 
					        self.learning_rate = learning_rate
 | 
				
			||||||
        self.n_steps = n_steps
 | 
					        self.n_steps = n_steps
 | 
				
			||||||
@ -110,7 +112,7 @@ class OnPolicy(ABC):
 | 
				
			|||||||
                    batch = batch.to(self.device)
 | 
					                    batch = batch.to(self.device)
 | 
				
			||||||
                    loss = self.train_step(batch)
 | 
					                    loss = self.train_step(batch)
 | 
				
			||||||
                    for logger in self.loggers:
 | 
					                    for logger in self.loggers:
 | 
				
			||||||
                        logger.log_scalar({"loss": loss.item()}, step=collected_frames)
 | 
					                        logger.log_scalar("loss", loss.item(), step=collected_frames)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if (t + 1) % self.eval_interval == 0:
 | 
					            if (t + 1) % self.eval_interval == 0:
 | 
				
			||||||
                self.evaluate(t)
 | 
					                self.evaluate(t)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,5 +1,5 @@
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
from torchrl.modules import ActorValueOperator, ProbabilisticActor
 | 
					from torchrl.modules import ProbabilisticActor
 | 
				
			||||||
from torchrl.objectives import ClipPPOLoss
 | 
					from torchrl.objectives import ClipPPOLoss
 | 
				
			||||||
from torchrl.objectives.value.advantages import GAE
 | 
					from torchrl.objectives.value.advantages import GAE
 | 
				
			||||||
from fancy_rl.algos.on_policy import OnPolicy
 | 
					from fancy_rl.algos.on_policy import OnPolicy
 | 
				
			||||||
@ -9,12 +9,11 @@ class PPO(OnPolicy):
 | 
				
			|||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        env_spec,
 | 
					        env_spec,
 | 
				
			||||||
        loggers=[],
 | 
					        loggers=None,
 | 
				
			||||||
        actor_hidden_sizes=[64, 64],
 | 
					        actor_hidden_sizes=[64, 64],
 | 
				
			||||||
        critic_hidden_sizes=[64, 64],
 | 
					        critic_hidden_sizes=[64, 64],
 | 
				
			||||||
        actor_activation_fn="Tanh",
 | 
					        actor_activation_fn="Tanh",
 | 
				
			||||||
        critic_activation_fn="Tanh",
 | 
					        critic_activation_fn="Tanh",
 | 
				
			||||||
        shared_stem_sizes=[64],
 | 
					 | 
				
			||||||
        learning_rate=3e-4,
 | 
					        learning_rate=3e-4,
 | 
				
			||||||
        n_steps=2048,
 | 
					        n_steps=2048,
 | 
				
			||||||
        batch_size=64,
 | 
					        batch_size=64,
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user