No pyc-files
This commit is contained in:
parent
6789a41730
commit
50c83db6e5
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -214,7 +214,34 @@ class TRL_PG(OnPolicyAlgorithm):
|
|||||||
if self.use_sde:
|
if self.use_sde:
|
||||||
self.policy.reset_noise(self.batch_size)
|
self.policy.reset_noise(self.batch_size)
|
||||||
|
|
||||||
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
|
# old code for PPO
|
||||||
|
# values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
|
||||||
|
|
||||||
|
# src in TRL reference code:
|
||||||
|
# p = self.policy(rollout_data.observations)
|
||||||
|
# proj_p = self.projection(self.policy, p, b_q = (b_old_mean, b_old_std), self._global_step)
|
||||||
|
# new_logpacs = self.policy.log_probability(proj_p, b_actions)
|
||||||
|
# log_prob == new_pogpacs (i think)
|
||||||
|
|
||||||
|
# src of evaluate_actions:
|
||||||
|
# features = self.extract_features(obs)
|
||||||
|
# latent_pi, latent_vf = self.mlp_extractor(features)
|
||||||
|
# distribution = self._get_action_dist_from_latent(latent_pi)
|
||||||
|
# log_prob = distribution.log_prob(actions)
|
||||||
|
# values = self.value_net(latent_vf)
|
||||||
|
# return values, log_prob, distribution.entropy()
|
||||||
|
|
||||||
|
# here we go:
|
||||||
|
pol = self.policy
|
||||||
|
feat = pol.extract_features(rollout_data.observations)
|
||||||
|
latent_pi, latent_vf = pol.mlp_extractor(features)
|
||||||
|
p = pol._get_action_dist_from_latent(latent_pi)
|
||||||
|
proj_p = self.projection(pol, p, b_q, self._global_step) # TODO: define b_q and global_step
|
||||||
|
log_prob = proj_p.log_prob(actions) # or log_prob = pol.log_probability(proj_p, actions)
|
||||||
|
values = self.value_net(latent_vf)
|
||||||
|
entropy = proj_p.entropy() # or not...
|
||||||
|
|
||||||
|
|
||||||
values = values.flatten()
|
values = values.flatten()
|
||||||
# Normalize advantage
|
# Normalize advantage
|
||||||
advantages = rollout_data.advantages
|
advantages = rollout_data.advantages
|
||||||
|
Loading…
Reference in New Issue
Block a user