mean and std are now saved to the rollout
This commit is contained in: parent df21e1dc3f, commit cf5a2e82fc
```diff
@@ -434,6 +434,11 @@ class TRL_PG(OnPolicyAlgorithm):
                 # Convert to pytorch tensor or to TensorDict
                 obs_tensor = obs_as_tensor(self._last_obs, self.device)
                 actions, values, log_probs = self.policy(obs_tensor)
+                dist = self.policy.get_distribution(obs_tensor)
+                # TODO: Enforce this requirement somewhere else...
+                assert isinstance(
+                    dist, th.distributions.Normal), 'TRL is only implemented for policies in a continuous action space that is Gaussian-parameterized!'
+                mean, std = dist.mean, dist.stddev
             actions = actions.cpu().numpy()

             # Rescale and perform action
```
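The five added lines pull the Gaussian parameters out of the policy's action distribution before the step is executed. Below is a minimal, self-contained sketch of what they compute, using a plain `th.distributions.Normal` as a stand-in for the policy output (the tensor shapes and values are illustrative, not taken from this commit). Note that in stock stable-baselines3, `policy.get_distribution()` returns an SB3 `Distribution` wrapper whose raw torch distribution sits on its `.distribution` attribute, so the `isinstance` check above relies on this fork's policy returning the torch object directly.

```python
import torch as th

# Stand-ins for the policy head outputs: 4 parallel envs, 2-dim actions.
action_mean = th.zeros(4, 2)
action_std = th.full((4, 2), 0.5)

dist = th.distributions.Normal(action_mean, action_std)
assert isinstance(dist, th.distributions.Normal)

# Per-step statistics that the commit passes on to the rollout buffer.
mean, std = dist.mean, dist.stddev
print(mean.shape, std.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```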
```diff
@@ -474,7 +479,6 @@ class TRL_PG(OnPolicyAlgorithm):
                             0]
                     rewards[idx] += self.gamma * terminal_value

-            # TODO: how to calc mean + std
-            rollout_buffer.add(self._last_obs, actions, rewards,
-                               self._last_episode_starts, values, log_probs)
+            rollout_buffer.add(self._last_obs, actions, rewards,
+                               self._last_episode_starts, values, log_probs, mean, std)
             self._last_obs = new_obs
```
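Since `rollout_buffer.add` now receives two extra arguments, the buffer itself must accept and store them; that change is not shown in this commit. A hedged sketch of one way to extend stable-baselines3's `RolloutBuffer` follows (the class name `GaussianRolloutBuffer` and the storage layout are assumptions for illustration, not taken from this repository):

```python
import numpy as np
from stable_baselines3.common.buffers import RolloutBuffer


class GaussianRolloutBuffer(RolloutBuffer):
    """RolloutBuffer that also stores the policy's Gaussian mean and std per step."""

    def reset(self) -> None:
        super().reset()
        # One (n_envs, action_dim) slot per buffered step for each statistic.
        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.stds = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)

    def add(self, obs, action, reward, episode_start, value, log_prob, mean, std) -> None:
        # Record the Gaussian parameters for the current step before the
        # parent implementation advances self.pos.
        self.means[self.pos] = mean.cpu().numpy()
        self.stds[self.pos] = std.cpu().numpy()
        super().add(obs, action, reward, episode_start, value, log_prob)
```

A full integration would also have to surface `means` and `stds` in `get()`/`_get_samples()`, presumably so that the trust-region update can compare the stored Gaussian parameters against the current policy's over the collected batch.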