Reorder: construct adv_module (GAE) and loss_module (ClipPPOLoss) after super().__init__() in PPO
This commit is contained in:
parent
d51bf948d4
commit
a867a74138
@@ -66,23 +66,6 @@ class PPO(OnPolicy):
|
|||||||
"critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
|
"critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
|
||||||
}
|
}
|
||||||
|
|
||||||
self.adv_module = GAE(
|
|
||||||
gamma=self.gamma,
|
|
||||||
lmbda=self.gae_lambda,
|
|
||||||
value_network=self.critic,
|
|
||||||
average_gae=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.loss_module = ClipPPOLoss(
|
|
||||||
actor_network=self.actor,
|
|
||||||
critic_network=self.critic,
|
|
||||||
clip_epsilon=self.clip_range,
|
|
||||||
loss_critic_type='MSELoss',
|
|
||||||
entropy_coef=self.entropy_coef,
|
|
||||||
critic_coef=self.critic_coef,
|
|
||||||
normalize_advantage=self.normalize_advantage,
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
policy=policy,
|
policy=policy,
|
||||||
env_spec=env_spec,
|
env_spec=env_spec,
|
||||||
@@ -105,3 +88,20 @@ class PPO(OnPolicy):
|
|||||||
env_spec_eval=env_spec_eval,
|
env_spec_eval=env_spec_eval,
|
||||||
eval_episodes=eval_episodes,
|
eval_episodes=eval_episodes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.adv_module = GAE(
|
||||||
|
gamma=self.gamma,
|
||||||
|
lmbda=self.gae_lambda,
|
||||||
|
value_network=self.critic,
|
||||||
|
average_gae=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.loss_module = ClipPPOLoss(
|
||||||
|
actor_network=self.actor,
|
||||||
|
critic_network=self.critic,
|
||||||
|
clip_epsilon=self.clip_range,
|
||||||
|
loss_critic_type='MSELoss',
|
||||||
|
entropy_coef=self.entropy_coef,
|
||||||
|
critic_coef=self.critic_coef,
|
||||||
|
normalize_advantage=self.normalize_advantage,
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user