From 59060c753334a69cf19505343f13e47f3df952bc Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sun, 2 Jun 2024 11:07:46 +0200
Subject: [PATCH] refactor

---
 fancy_rl/ppo.py | 40 ++++++++++++++++++++--------------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/fancy_rl/ppo.py b/fancy_rl/ppo.py
index 0459a2d..9aeeed1 100644
--- a/fancy_rl/ppo.py
+++ b/fancy_rl/ppo.py
@@ -53,7 +53,7 @@ class PPO(OnPolicy):
         )
 
         # Define the policy as a ProbabilisticActor
-        self.policy = ProbabilisticActor(
+        policy = ProbabilisticActor(
             module=self.ac_module.get_policy_operator(),
             in_keys=["loc", "scale"],
             out_keys=["action"],
@@ -66,8 +66,25 @@ class PPO(OnPolicy):
             "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
         }
 
+        self.adv_module = GAE(
+            gamma=self.gamma,
+            lmbda=self.gae_lambda,
+            value_network=self.critic,
+            average_gae=False,
+        )
+
+        self.loss_module = ClipPPOLoss(
+            actor_network=self.actor,
+            critic_network=self.critic,
+            clip_epsilon=self.clip_range,
+            loss_critic_type='MSELoss',
+            entropy_coef=self.entropy_coef,
+            critic_coef=self.critic_coef,
+            normalize_advantage=self.normalize_advantage,
+        )
+
         super().__init__(
-            policy=self.policy,
+            policy=policy,
             env_spec=env_spec,
             loggers=loggers,
             optimizers=optimizers,
@@ -87,21 +104,4 @@ class PPO(OnPolicy):
             device=device,
             env_spec_eval=env_spec_eval,
             eval_episodes=eval_episodes,
-        )
-
-        self.adv_module = GAE(
-            gamma=self.gamma,
-            lmbda=self.gae_lambda,
-            value_network=self.critic,
-            average_gae=False,
-        )
-
-        self.loss_module = ClipPPOLoss(
-            actor_network=self.actor,
-            critic_network=self.critic,
-            clip_epsilon=self.clip_range,
-            loss_critic_type='MSELoss',
-            entropy_coef=self.entropy_coef,
-            critic_coef=self.critic_coef,
-            normalize_advantage=self.normalize_advantage,
-        )
+        )
\ No newline at end of file
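
For context, a minimal sketch of how the relocated adv_module and loss_module are typically driven in a TorchRL-style update step; the ppo_update function, the batch TensorDict, and the optimizer argument are assumptions for illustration, not part of this patch:

    import torch

    # Sketch only (not in the patch): `ppo` is a constructed PPO instance,
    # `batch` a collected rollout TensorDict, `optimizer` any torch optimizer.
    def ppo_update(ppo, batch, optimizer):
        with torch.no_grad():
            ppo.adv_module(batch)        # GAE writes "advantage" / "value_target" keys
        losses = ppo.loss_module(batch)  # ClipPPOLoss returns a TensorDict of partial losses
        loss = losses["loss_objective"] + losses["loss_critic"] + losses["loss_entropy"]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Because both modules are now built before super().__init__() is called, any setup done in the base class can rely on them already existing on the instance.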