refactor
This commit is contained in:
parent
a3cca71ac9
commit
59060c7533
@ -53,7 +53,7 @@ class PPO(OnPolicy):
|
||||
)
|
||||
|
||||
# Define the policy as a ProbabilisticActor
|
||||
self.policy = ProbabilisticActor(
|
||||
policy = ProbabilisticActor(
|
||||
module=self.ac_module.get_policy_operator(),
|
||||
in_keys=["loc", "scale"],
|
||||
out_keys=["action"],
|
||||
@ -66,8 +66,25 @@ class PPO(OnPolicy):
|
||||
"critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
|
||||
}
|
||||
|
||||
self.adv_module = GAE(
|
||||
gamma=self.gamma,
|
||||
lmbda=self.gae_lambda,
|
||||
value_network=self.critic,
|
||||
average_gae=False,
|
||||
)
|
||||
|
||||
self.loss_module = ClipPPOLoss(
|
||||
actor_network=self.actor,
|
||||
critic_network=self.critic,
|
||||
clip_epsilon=self.clip_range,
|
||||
loss_critic_type='MSELoss',
|
||||
entropy_coef=self.entropy_coef,
|
||||
critic_coef=self.critic_coef,
|
||||
normalize_advantage=self.normalize_advantage,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
policy=self.policy,
|
||||
policy=policy,
|
||||
env_spec=env_spec,
|
||||
loggers=loggers,
|
||||
optimizers=optimizers,
|
||||
@ -87,21 +104,4 @@ class PPO(OnPolicy):
|
||||
device=device,
|
||||
env_spec_eval=env_spec_eval,
|
||||
eval_episodes=eval_episodes,
|
||||
)
|
||||
|
||||
self.adv_module = GAE(
|
||||
gamma=self.gamma,
|
||||
lmbda=self.gae_lambda,
|
||||
value_network=self.critic,
|
||||
average_gae=False,
|
||||
)
|
||||
|
||||
self.loss_module = ClipPPOLoss(
|
||||
actor_network=self.actor,
|
||||
critic_network=self.critic,
|
||||
clip_epsilon=self.clip_range,
|
||||
loss_critic_type='MSELoss',
|
||||
entropy_coef=self.entropy_coef,
|
||||
critic_coef=self.critic_coef,
|
||||
normalize_advantage=self.normalize_advantage,
|
||||
)
|
||||
)
|
Loading…
Reference in New Issue
Block a user