Reorder...

This commit is contained in:
Dominik Moritz Roth 2024-06-02 12:09:26 +02:00
parent d51bf948d4
commit a867a74138

View File

@ -66,23 +66,6 @@ class PPO(OnPolicy):
"critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
}
self.adv_module = GAE(
gamma=self.gamma,
lmbda=self.gae_lambda,
value_network=self.critic,
average_gae=False,
)
self.loss_module = ClipPPOLoss(
actor_network=self.actor,
critic_network=self.critic,
clip_epsilon=self.clip_range,
loss_critic_type='MSELoss',
entropy_coef=self.entropy_coef,
critic_coef=self.critic_coef,
normalize_advantage=self.normalize_advantage,
)
super().__init__(
policy=policy,
env_spec=env_spec,
@ -105,3 +88,20 @@ class PPO(OnPolicy):
env_spec_eval=env_spec_eval,
eval_episodes=eval_episodes,
)
self.adv_module = GAE(
gamma=self.gamma,
lmbda=self.gae_lambda,
value_network=self.critic,
average_gae=False,
)
self.loss_module = ClipPPOLoss(
actor_network=self.actor,
critic_network=self.critic,
clip_epsilon=self.clip_range,
loss_critic_type='MSELoss',
entropy_coef=self.entropy_coef,
critic_coef=self.critic_coef,
normalize_advantage=self.normalize_advantage,
)