diff --git a/fancy_rl/algos/ppo.py b/fancy_rl/algos/ppo.py
index 6481835..b70737e 100644
--- a/fancy_rl/algos/ppo.py
+++ b/fancy_rl/algos/ppo.py
@@ -66,23 +66,6 @@ class PPO(OnPolicy):
             "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
         }
 
-        self.adv_module = GAE(
-            gamma=self.gamma,
-            lmbda=self.gae_lambda,
-            value_network=self.critic,
-            average_gae=False,
-        )
-
-        self.loss_module = ClipPPOLoss(
-            actor_network=self.actor,
-            critic_network=self.critic,
-            clip_epsilon=self.clip_range,
-            loss_critic_type='MSELoss',
-            entropy_coef=self.entropy_coef,
-            critic_coef=self.critic_coef,
-            normalize_advantage=self.normalize_advantage,
-        )
-
         super().__init__(
             policy=policy,
             env_spec=env_spec,
@@ -104,4 +87,21 @@ class PPO(OnPolicy):
             device=device,
             env_spec_eval=env_spec_eval,
             eval_episodes=eval_episodes,
+        )
+
+        self.adv_module = GAE(
+            gamma=self.gamma,
+            lmbda=self.gae_lambda,
+            value_network=self.critic,
+            average_gae=False,
+        )
+
+        self.loss_module = ClipPPOLoss(
+            actor_network=self.actor,
+            critic_network=self.critic,
+            clip_epsilon=self.clip_range,
+            loss_critic_type='MSELoss',
+            entropy_coef=self.entropy_coef,
+            critic_coef=self.critic_coef,
+            normalize_advantage=self.normalize_advantage,
         )
\ No newline at end of file
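
Not part of the patch: below is a minimal sketch of how the two TorchRL modules configured in this __init__ (GAE and ClipPPOLoss) are typically consumed in an update step, assuming the standard TorchRL actor/critic wrappers. The network shapes, hyperparameter values, and rollout keys are illustrative assumptions, not fancy_rl code.

import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

obs_dim, act_dim = 4, 2  # hypothetical dimensions

# Critic: reads "observation", writes "state_value".
critic = ValueOperator(torch.nn.Linear(obs_dim, 1), in_keys=["observation"])

# Actor: maps "observation" to (loc, scale) and samples actions from a TanhNormal.
actor = ProbabilisticActor(
    TensorDictModule(
        torch.nn.Sequential(torch.nn.Linear(obs_dim, 2 * act_dim), NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,  # stores the sampled action's log-probability for ClipPPOLoss
)

adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=critic, average_gae=False)
loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    clip_epsilon=0.2,
    loss_critic_type="l2",  # one of the loss types accepted by TorchRL's distance_loss
    entropy_coef=0.01,
    critic_coef=1.0,
    normalize_advantage=True,
)

# Fake rollout batch carrying the keys GAE and ClipPPOLoss expect.
steps = 32
td = TensorDict(
    {
        "observation": torch.randn(steps, obs_dim),
        "next": {
            "observation": torch.randn(steps, obs_dim),
            "reward": torch.randn(steps, 1),
            "done": torch.zeros(steps, 1, dtype=torch.bool),
            "terminated": torch.zeros(steps, 1, dtype=torch.bool),
        },
    },
    batch_size=[steps],
)

with torch.no_grad():
    actor(td)       # adds "action" and the sampled log-probability
    adv_module(td)  # adds "advantage" and "value_target"

losses = loss_module(td)
total_loss = losses["loss_objective"] + losses["loss_critic"] + losses["loss_entropy"]
total_loss.backward()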