mean and std are now saved to the rollout

This commit is contained in:
Dominik Moritz Roth 2022-06-25 18:29:55 +02:00
parent df21e1dc3f
commit cf5a2e82fc

View File

@ -434,6 +434,11 @@ class TRL_PG(OnPolicyAlgorithm):
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)
actions, values, log_probs = self.policy(obs_tensor)
dist = self.policy.get_distribution(obs_tensor)
# TODO: Enforce this requirement somwhere else...
assert isinstance(
dist, th.distributions.Normal), 'TRL is only implemented for Policys in a continuous action-space that is gauss-parametarized!'
mean, std = dist.mean, dist.stddev
actions = actions.cpu().numpy()
# Rescale and perform action
@ -474,7 +479,6 @@ class TRL_PG(OnPolicyAlgorithm):
0]
rewards[idx] += self.gamma * terminal_value
# TODO: how to calc mean + std
rollout_buffer.add(self._last_obs, actions, rewards,
self._last_episode_starts, values, log_probs, mean, std)
self._last_obs = new_obs