refactor
This commit is contained in:
parent
a3cca71ac9
commit
59060c7533
@ -53,7 +53,7 @@ class PPO(OnPolicy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Define the policy as a ProbabilisticActor
|
# Define the policy as a ProbabilisticActor
|
||||||
self.policy = ProbabilisticActor(
|
policy = ProbabilisticActor(
|
||||||
module=self.ac_module.get_policy_operator(),
|
module=self.ac_module.get_policy_operator(),
|
||||||
in_keys=["loc", "scale"],
|
in_keys=["loc", "scale"],
|
||||||
out_keys=["action"],
|
out_keys=["action"],
|
||||||
@ -66,8 +66,25 @@ class PPO(OnPolicy):
|
|||||||
"critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
|
"critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.adv_module = GAE(
|
||||||
|
gamma=self.gamma,
|
||||||
|
lmbda=self.gae_lambda,
|
||||||
|
value_network=self.critic,
|
||||||
|
average_gae=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.loss_module = ClipPPOLoss(
|
||||||
|
actor_network=self.actor,
|
||||||
|
critic_network=self.critic,
|
||||||
|
clip_epsilon=self.clip_range,
|
||||||
|
loss_critic_type='MSELoss',
|
||||||
|
entropy_coef=self.entropy_coef,
|
||||||
|
critic_coef=self.critic_coef,
|
||||||
|
normalize_advantage=self.normalize_advantage,
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
policy=self.policy,
|
policy=policy,
|
||||||
env_spec=env_spec,
|
env_spec=env_spec,
|
||||||
loggers=loggers,
|
loggers=loggers,
|
||||||
optimizers=optimizers,
|
optimizers=optimizers,
|
||||||
@ -88,20 +105,3 @@ class PPO(OnPolicy):
|
|||||||
env_spec_eval=env_spec_eval,
|
env_spec_eval=env_spec_eval,
|
||||||
eval_episodes=eval_episodes,
|
eval_episodes=eval_episodes,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.adv_module = GAE(
|
|
||||||
gamma=self.gamma,
|
|
||||||
lmbda=self.gae_lambda,
|
|
||||||
value_network=self.critic,
|
|
||||||
average_gae=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.loss_module = ClipPPOLoss(
|
|
||||||
actor_network=self.actor,
|
|
||||||
critic_network=self.critic,
|
|
||||||
clip_epsilon=self.clip_range,
|
|
||||||
loss_critic_type='MSELoss',
|
|
||||||
entropy_coef=self.entropy_coef,
|
|
||||||
critic_coef=self.critic_coef,
|
|
||||||
normalize_advantage=self.normalize_advantage,
|
|
||||||
)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user