diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py index d217f89..458c4e7 100644 --- a/sb3_trl/trl_pg/trl_pg.py +++ b/sb3_trl/trl_pg/trl_pg.py @@ -434,6 +434,11 @@ class TRL_PG(OnPolicyAlgorithm): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) actions, values, log_probs = self.policy(obs_tensor) + dist = self.policy.get_distribution(obs_tensor) + # TODO: Enforce this requirement somewhere else... + assert isinstance( + dist, th.distributions.Normal), 'TRL is only implemented for Policys in a continuous action-space that is gauss-parametarized!' + mean, std = dist.mean, dist.stddev actions = actions.cpu().numpy() # Rescale and perform action @@ -474,7 +479,6 @@ class TRL_PG(OnPolicyAlgorithm): 0] rewards[idx] += self.gamma * terminal_value - # TODO: how to calc mean + std rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, mean, std) self._last_obs = new_obs