No pyc-files

Dominik Moritz Roth 2022-06-17 15:59:43 +02:00
parent 6789a41730
commit 50c83db6e5
5 changed files with 28 additions and 1 deletion
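
The commit title suggests compiled Python artifacts were removed from version control. The .gitignore hunk itself is not shown here, but a typical set of entries for this purpose (assumed, not taken from the commit) would be:

    # hypothetical .gitignore entries matching the commit's stated intent
    *.pyc
    __pycache__/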

@@ -214,7 +214,34 @@ class TRL_PG(OnPolicyAlgorithm):
if self.use_sde:
    self.policy.reset_noise(self.batch_size)
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
# old code for PPO
# values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
# src in TRL reference code:
# p = self.policy(rollout_data.observations)
# proj_p = self.projection(self.policy, p, b_q=(b_old_mean, b_old_std), self._global_step)
# new_logpacs = self.policy.log_probability(proj_p, b_actions)
# log_prob == new_logpacs (I think)
# src of evaluate_actions:
# features = self.extract_features(obs)
# latent_pi, latent_vf = self.mlp_extractor(features)
# distribution = self._get_action_dist_from_latent(latent_pi)
# log_prob = distribution.log_prob(actions)
# values = self.value_net(latent_vf)
# return values, log_prob, distribution.entropy()
# here we go:
pol = self.policy
feat = pol.extract_features(rollout_data.observations)
latent_pi, latent_vf = pol.mlp_extractor(feat)
p = pol._get_action_dist_from_latent(latent_pi)
proj_p = self.projection(pol, p, b_q, self._global_step)  # TODO: define b_q and global_step
log_prob = proj_p.log_prob(actions)  # or log_prob = pol.log_probability(proj_p, actions)
values = pol.value_net(latent_vf)  # value_net lives on the policy, not the algorithm
entropy = proj_p.entropy()  # or p.entropy(), if the bonus should use the unprojected distribution
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
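
The TODO above leaves b_q and self._global_step undefined. In the TRL reference code, b_q holds the behavior policy's old distribution parameters (b_old_mean, b_old_std), recorded before any gradient step so the projection has a fixed reference. A minimal sketch of one way to capture them, assuming a diagonal-Gaussian SB3 policy and reusing the names pol and rollout_data from the hunk above; this is an illustration, not the commit's code:

    import torch as th

    # Capture the old (behavior) distribution once per epoch, before the
    # first gradient step, so the trust region is measured against a
    # fixed reference distribution rather than the moving current policy.
    with th.no_grad():
        feat_old = pol.extract_features(rollout_data.observations)
        latent_pi_old, _ = pol.mlp_extractor(feat_old)
        p_old = pol._get_action_dist_from_latent(latent_pi_old)
        # mean/stddev of the underlying torch.distributions.Normal
        b_q = (p_old.distribution.mean.clone(), p_old.distribution.stddev.clone())

    # self._global_step could simply be a counter incremented once per
    # train() call (assumption; the projection may use it for annealing).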