diff --git a/sb3_trl/__pycache__/__init__.cpython-310.pyc b/sb3_trl/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 61a5b16..0000000
Binary files a/sb3_trl/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/sb3_trl/trl_pg/__pycache__/__init__.cpython-310.pyc b/sb3_trl/trl_pg/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 6f7ab44..0000000
Binary files a/sb3_trl/trl_pg/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/sb3_trl/trl_pg/__pycache__/policies.cpython-310.pyc b/sb3_trl/trl_pg/__pycache__/policies.cpython-310.pyc
deleted file mode 100644
index 45e52dd..0000000
Binary files a/sb3_trl/trl_pg/__pycache__/policies.cpython-310.pyc and /dev/null differ
diff --git a/sb3_trl/trl_pg/__pycache__/trl_pg.cpython-310.pyc b/sb3_trl/trl_pg/__pycache__/trl_pg.cpython-310.pyc
deleted file mode 100644
index 37e5875..0000000
Binary files a/sb3_trl/trl_pg/__pycache__/trl_pg.cpython-310.pyc and /dev/null differ
diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py
index 1d813f5..41405b7 100644
--- a/sb3_trl/trl_pg/trl_pg.py
+++ b/sb3_trl/trl_pg/trl_pg.py
@@ -214,7 +214,34 @@ class TRL_PG(OnPolicyAlgorithm):
                 if self.use_sde:
                     self.policy.reset_noise(self.batch_size)

-                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
+                # Old PPO code:
+                # values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
+
+                # Reference code in the TRL implementation:
+                # p = self.policy(rollout_data.observations)
+                # proj_p = self.projection(self.policy, p, b_q=(b_old_mean, b_old_std), self._global_step)
+                # new_logpacs = self.policy.log_probability(proj_p, b_actions)
+                # log_prob == new_logpacs (I think)
+
+                # Source of evaluate_actions:
+                # features = self.extract_features(obs)
+                # latent_pi, latent_vf = self.mlp_extractor(features)
+                # distribution = self._get_action_dist_from_latent(latent_pi)
+                # log_prob = distribution.log_prob(actions)
+                # values = self.value_net(latent_vf)
+                # return values, log_prob, distribution.entropy()
+
+                # Combined here: evaluate the current policy, project its distribution
+                # against the old one, then take log-prob and entropy from the projection.
+                pol = self.policy
+                features = pol.extract_features(rollout_data.observations)
+                latent_pi, latent_vf = pol.mlp_extractor(features)
+                p = pol._get_action_dist_from_latent(latent_pi)
+                proj_p = self.projection(pol, p, b_q, self._global_step)  # TODO: define b_q and self._global_step
+                log_prob = proj_p.log_prob(actions)  # or: log_prob = pol.log_probability(proj_p, actions)
+                values = pol.value_net(latent_vf)
+                entropy = proj_p.entropy()  # TODO: projected or unprojected entropy?
+
                 values = values.flatten()
                 # Normalize advantage
                 advantages = rollout_data.advantages
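
Note: the projected evaluation added above still leaves b_q and self._global_step undefined (see the TODO). Below is a minimal sketch, not the repo's final code, of how that block could be factored into a standalone helper. It assumes the SB3 ActorCriticPolicy internals quoted in the comments (extract_features, mlp_extractor, _get_action_dist_from_latent, value_net) and a projection_fn with the reference signature that returns a distribution-like object; b_q (the old policy's mean/std recorded during rollout collection) and global_step are assumptions, not existing fields.

    def evaluate_actions_projected(policy, projection_fn, obs, actions, b_q, global_step):
        """Mirror of ActorCriticPolicy.evaluate_actions, but the new action
        distribution is projected back into the trust region around the old
        policy before log-probabilities and entropy are computed.

        Assumptions (not stock SB3): `projection_fn(policy, new_dist, b_q, step)`
        returns an object exposing log_prob() / entropy(); `b_q` holds the old
        policy's (mean, std) stored while collecting the rollout.
        """
        # Standard SB3 forward pass up to the action distribution.
        features = policy.extract_features(obs)
        latent_pi, latent_vf = policy.mlp_extractor(features)
        new_dist = policy._get_action_dist_from_latent(latent_pi)

        # Project the new distribution against the old (behaviour) distribution.
        proj_dist = projection_fn(policy, new_dist, b_q, global_step)

        # Log-prob and entropy come from the projected distribution;
        # the value head is unaffected by the projection.
        log_prob = proj_dist.log_prob(actions)
        values = policy.value_net(latent_vf)
        entropy = proj_dist.entropy()
        return values.flatten(), log_prob, entropy

Inside train() the call site would then shrink to `values, log_prob, entropy = evaluate_actions_projected(self.policy, self.projection, rollout_data.observations, actions, b_q, self._global_step)`, with b_q still to be filled in from whatever the rollout buffer ends up storing about the old distribution.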