Dominik Moritz Roth 2024-06-02 11:07:46 +02:00
parent a3cca71ac9
commit 59060c7533


@@ -53,7 +53,7 @@ class PPO(OnPolicy):
         )
         # Define the policy as a ProbabilisticActor
-        self.policy = ProbabilisticActor(
+        policy = ProbabilisticActor(
             module=self.ac_module.get_policy_operator(),
             in_keys=["loc", "scale"],
             out_keys=["action"],
@@ -66,8 +66,25 @@ class PPO(OnPolicy):
             "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
         }
+        self.adv_module = GAE(
+            gamma=self.gamma,
+            lmbda=self.gae_lambda,
+            value_network=self.critic,
+            average_gae=False,
+        )
+        self.loss_module = ClipPPOLoss(
+            actor_network=self.actor,
+            critic_network=self.critic,
+            clip_epsilon=self.clip_range,
+            loss_critic_type='MSELoss',
+            entropy_coef=self.entropy_coef,
+            critic_coef=self.critic_coef,
+            normalize_advantage=self.normalize_advantage,
+        )
         super().__init__(
-            policy=self.policy,
+            policy=policy,
             env_spec=env_spec,
             loggers=loggers,
             optimizers=optimizers,
@@ -88,20 +105,3 @@ class PPO(OnPolicy):
             env_spec_eval=env_spec_eval,
             eval_episodes=eval_episodes,
         )
-        self.adv_module = GAE(
-            gamma=self.gamma,
-            lmbda=self.gae_lambda,
-            value_network=self.critic,
-            average_gae=False,
-        )
-        self.loss_module = ClipPPOLoss(
-            actor_network=self.actor,
-            critic_network=self.critic,
-            clip_epsilon=self.clip_range,
-            loss_critic_type='MSELoss',
-            entropy_coef=self.entropy_coef,
-            critic_coef=self.critic_coef,
-            normalize_advantage=self.normalize_advantage,
-        )
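
For context, a minimal sketch of how the adv_module (TorchRL GAE) and loss_module (ClipPPOLoss) constructed above are typically driven in a PPO update step. The tensordict name `batch`, the `self.optimizers` attribute, and the way the loss terms are summed are assumptions for illustration, not part of this commit.

    # Hypothetical update step (not part of this commit). `batch` is assumed to be
    # an on-policy TensorDict containing the fields the actor and critic expect.
    with torch.no_grad():
        # GAE writes "advantage" and "value_target" entries into the tensordict.
        self.adv_module(batch)

    # ClipPPOLoss returns a TensorDict with one entry per loss term.
    loss_td = self.loss_module(batch)
    loss = loss_td["loss_objective"] + loss_td["loss_critic"] + loss_td["loss_entropy"]

    self.optimizers["actor"].zero_grad()
    self.optimizers["critic"].zero_grad()
    loss.backward()
    self.optimizers["actor"].step()
    self.optimizers["critic"].step()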