mean and std are now saved to the rollout buffer
This commit is contained in:
parent df21e1dc3f
commit cf5a2e82fc
@@ -434,6 +434,11 @@ class TRL_PG(OnPolicyAlgorithm):
 # Convert to pytorch tensor or to TensorDict
 obs_tensor = obs_as_tensor(self._last_obs, self.device)
 actions, values, log_probs = self.policy(obs_tensor)
+dist = self.policy.get_distribution(obs_tensor)
+# TODO: Enforce this requirement somewhere else...
+assert isinstance(
+    dist, th.distributions.Normal), 'TRL is only implemented for policies in a continuous action space with a Gaussian parameterization!'
+mean, std = dist.mean, dist.stddev
 actions = actions.cpu().numpy()

 # Rescale and perform action
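For reference, a minimal self-contained sketch (plain PyTorch, not the project's actual policy class) of what the added lines compute: query the Gaussian action distribution for a batch of observations and read off its mean and standard deviation alongside the sampled actions, presumably so the old policy's Gaussian can later be reconstructed without keeping the old network around. ToyGaussianPolicy and its dimensions are made up for illustration.

import torch as th
import torch.nn as nn


class ToyGaussianPolicy(nn.Module):
    """Hypothetical stand-in for a Gaussian-parameterized actor (illustration only)."""

    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self.mu_net = nn.Linear(obs_dim, act_dim)       # state-dependent mean
        self.log_std = nn.Parameter(th.zeros(act_dim))  # state-independent log std

    def get_distribution(self, obs: th.Tensor) -> th.distributions.Normal:
        # Diagonal Gaussian over actions, matching the assert in the hunk above.
        return th.distributions.Normal(self.mu_net(obs), self.log_std.exp())


policy = ToyGaussianPolicy(obs_dim=4, act_dim=2)
obs_tensor = th.randn(8, 4)                         # batch of 8 observations

dist = policy.get_distribution(obs_tensor)
assert isinstance(dist, th.distributions.Normal)
mean, std = dist.mean, dist.stddev                  # per-step policy parameters
actions = dist.sample()
log_probs = dist.log_prob(actions).sum(dim=-1)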
@@ -474,7 +479,6 @@ class TRL_PG(OnPolicyAlgorithm):
 0]
 rewards[idx] += self.gamma * terminal_value

-# TODO: how to calc mean + std
 rollout_buffer.add(self._last_obs, actions, rewards,
                    self._last_episode_starts, values, log_probs, mean, std)
 self._last_obs = new_obs
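The second hunk then threads those per-step parameters into rollout_buffer.add, which also resolves the removed '# TODO: how to calc mean + std' line: the parameters are captured at collection time rather than recomputed later. Below is a hedged, NumPy-only sketch of a buffer accepting this extended signature; GaussianRolloutBuffer, its shapes, and the single-environment layout are illustrative assumptions, not the project's actual TRL rollout buffer.

import numpy as np


class GaussianRolloutBuffer:
    """Toy single-environment rollout buffer that also stores per-step mean/std."""

    def __init__(self, buffer_size: int, obs_dim: int, act_dim: int):
        self.observations = np.zeros((buffer_size, obs_dim), dtype=np.float32)
        self.actions = np.zeros((buffer_size, act_dim), dtype=np.float32)
        self.rewards = np.zeros(buffer_size, dtype=np.float32)
        self.episode_starts = np.zeros(buffer_size, dtype=np.float32)
        self.values = np.zeros(buffer_size, dtype=np.float32)
        self.log_probs = np.zeros(buffer_size, dtype=np.float32)
        # New fields mirroring the extended add() signature in the diff.
        self.means = np.zeros((buffer_size, act_dim), dtype=np.float32)
        self.stds = np.zeros((buffer_size, act_dim), dtype=np.float32)
        self.pos = 0

    def add(self, obs, action, reward, episode_start, value, log_prob, mean, std):
        # Caller is expected to pass detached NumPy arrays / plain floats.
        self.observations[self.pos] = np.asarray(obs, dtype=np.float32)
        self.actions[self.pos] = np.asarray(action, dtype=np.float32)
        self.rewards[self.pos] = float(reward)
        self.episode_starts[self.pos] = float(episode_start)
        self.values[self.pos] = float(value)
        self.log_probs[self.pos] = float(log_prob)
        self.means[self.pos] = np.asarray(mean, dtype=np.float32)
        self.stds[self.pos] = np.asarray(std, dtype=np.float32)
        self.pos += 1


buf = GaussianRolloutBuffer(buffer_size=2048, obs_dim=4, act_dim=2)
buf.add(obs=np.zeros(4), action=np.zeros(2), reward=0.0, episode_start=1.0,
        value=0.0, log_prob=0.0, mean=np.zeros(2), std=np.ones(2))

Storing mean and std per transition costs only a few extra floats per step, so the overhead is negligible compared to re-evaluating the old policy later.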