_global_steps ctr added

parent 866f863d70
commit cafc90409f
@@ -167,6 +167,7 @@ class TRL_PG(OnPolicyAlgorithm):
 
         # Different from PPO:
         self.projection = projection
+        self._global_steps = 0
 
         if _init_setup_model:
             self._setup_model()
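The counter is initialised next to the projection object. In trust-region-layer style training the projection's bound is typically annealed as a function of the training step, which is why the projection later receives this counter. A minimal sketch of such a step-dependent bound schedule; the class name, parameters, and default values below are illustrative assumptions, not code from this repository:

```python
# Hedged sketch: a trust-region bound that shrinks linearly with the step.
# All names and default values here are illustrative assumptions.
class LinearBoundSchedule:
    def __init__(self, initial_bound: float = 0.1,
                 target_bound: float = 0.01,
                 total_steps: int = 100_000):
        self.initial_bound = initial_bound
        self.target_bound = target_bound
        self.total_steps = total_steps

    def __call__(self, global_step: int) -> float:
        # Fraction of training completed, clipped to [0, 1].
        frac = min(max(global_step / self.total_steps, 0.0), 1.0)
        return self.initial_bound + frac * (self.target_bound - self.initial_bound)

# Example: bound = LinearBoundSchedule()(self._global_steps)
```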
@@ -210,6 +211,10 @@ class TRL_PG(OnPolicyAlgorithm):
             approx_kl_divs = []
             # Do a complete pass on the rollout buffer
             for rollout_data in self.rollout_buffer.get(self.batch_size):
+                # This is new compared to PPO:
+                # to calculate the TR projections we need to know the step number
+                self._global_steps += 1
+
                 actions = rollout_data.actions
                 if isinstance(self.action_space, spaces.Discrete):
                     # Convert discrete action from float to long
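Because the increment sits inside the minibatch loop of the PPO-style `train()` (which is itself wrapped in an epoch loop), `_global_steps` counts gradient updates rather than environment steps. A small worked example under assumed stable-baselines3-style hyperparameters; the concrete numbers are illustrative, not this repository's defaults:

```python
# Hedged sketch of what the counter ends up counting.
n_epochs = 10        # passes over the rollout buffer per train() call
buffer_size = 2048   # n_steps * n_envs collected per rollout
batch_size = 64

updates_per_train_call = n_epochs * (buffer_size // batch_size)
print(updates_per_train_call)  # 320

# After k calls to train(), self._global_steps == k * 320,
# i.e. it counts minibatch updates, not environment steps.
```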
@@ -241,7 +246,7 @@ class TRL_PG(OnPolicyAlgorithm):
                 features = pol.extract_features(rollout_data.observations)
                 latent_pi, latent_vf = pol.mlp_extractor(features)
                 p = pol._get_action_dist_from_latent(latent_pi)
-                # TODO: define b_q and global_step
+                b_q = rollout_data.mean, rollout_data.std
                 proj_p = self.projection(pol, p, b_q, self._global_steps)
                 log_prob = proj_p.log_prob(actions)
                 # or log_prob = pol.log_probability(proj_p, actions)
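Here `b_q` bundles the behaviour policy's mean and std from the rollout buffer, and `self.projection` maps the current distribution `p` back into a trust region around that old distribution before log-probabilities are computed. As a rough illustration of the idea only (the repository's `self.projection` is a differentiable trust-region layer with a different interface), a naive mean-only projection for diagonal Gaussian policies could look like this:

```python
import torch as th
from torch.distributions import Normal

def naive_mean_projection(p_mean: th.Tensor, p_std: th.Tensor,
                          q_mean: th.Tensor, q_std: th.Tensor,
                          mean_bound: float = 0.03) -> Normal:
    """Hedged sketch: rescale the new mean toward the old (behaviour) mean
    whenever the squared Mahalanobis distance exceeds the bound."""
    diff = (p_mean - q_mean) / q_std
    dist = (diff ** 2).sum(-1, keepdim=True)          # per-sample distance
    # Shrink the mean shift wherever the trust-region bound is violated.
    scale = th.where(dist > mean_bound,
                     th.sqrt(mean_bound / (dist + 1e-8)),
                     th.ones_like(dist))
    proj_mean = q_mean + scale * (p_mean - q_mean)
    return Normal(proj_mean, p_std)
```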
@@ -466,8 +471,9 @@ class TRL_PG(OnPolicyAlgorithm):
                         0]
                 rewards[idx] += self.gamma * terminal_value
 
+            # TODO: calc mean + std
             rollout_buffer.add(self._last_obs, actions, rewards,
-                               self._last_episode_starts, values, log_probs)
+                               self._last_episode_starts, values, log_probs, mean, std)
             self._last_obs = new_obs
             self._last_episode_starts = dones
 
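The extended `add()` call implies a rollout buffer that stores the behaviour policy's mean and std per transition (the `# TODO: calc mean + std` still has to produce those values from the policy during collection). A minimal sketch of such a buffer, assuming a stable-baselines3 `RolloutBuffer` base class and a diagonal Gaussian policy; the class and field names are assumptions, not this repository's actual buffer:

```python
import numpy as np
from stable_baselines3.common.buffers import RolloutBuffer

class GaussianRolloutBuffer(RolloutBuffer):
    """Hedged sketch: additionally store the behaviour policy's mean and std
    for every transition so the projection can rebuild the old distribution."""

    def reset(self) -> None:
        super().reset()
        # One (mean, std) pair per step and per parallel environment.
        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.stds = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)

    def add(self, obs, action, reward, episode_start, value, log_prob,
            mean=None, std=None) -> None:
        # Record the old-policy statistics before the base class advances self.pos.
        if mean is not None:
            self.means[self.pos] = np.asarray(mean, dtype=np.float32).reshape(self.n_envs, self.action_dim)
        if std is not None:
            self.stds[self.pos] = np.asarray(std, dtype=np.float32).reshape(self.n_envs, self.action_dim)
        super().add(obs, action, reward, episode_start, value, log_prob)
```

For the training hunk above to work, the buffer's `get()` (and its sample named-tuple) would also have to expose the stored values as `rollout_data.mean` and `rollout_data.std`.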