diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py
index 37d86b2..ef1b62c 100644
--- a/sb3_trl/trl_pg/trl_pg.py
+++ b/sb3_trl/trl_pg/trl_pg.py
@@ -15,6 +15,9 @@
 from stable_baselines3.common.buffers import RolloutBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.utils import obs_as_tensor
+from ..projections.base_projection_layer import BaseProjectionLayer
+from ..projections.frob_projection_layer import FrobeniusProjectionLayer
+
 
 class TRL_PG(OnPolicyAlgorithm):
     """
@@ -63,9 +66,9 @@ class TRL_PG(OnPolicyAlgorithm):
     :param seed: Seed for the pseudo random generators
     :param device: Device (cpu, cuda, ...) on which the code should be run.
         Setting it to auto, the code will be run on the GPU if possible.
+    :param projection: Which projection layer to use
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
-    # TODO: Add new params to doc
 
     policy_aliases: Dict[str, Type[BasePolicy]] = {
         "MlpPolicy": ActorCriticPolicy,
@@ -100,8 +103,7 @@ class TRL_PG(OnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
 
         # Different from PPO:
-        #projection: BaseProjectionLayer = None,
-        projection=None,
+        projection: BaseProjectionLayer = BaseProjectionLayer,
         _init_setup_model: bool = True,
     ):
 
@@ -300,9 +302,9 @@ class TRL_PG(OnPolicyAlgorithm):
                 entropy_losses.append(entropy_loss.item())
 
                 # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
-                # trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
-                trust_region_loss = th.zeros(
-                    1, device=entropy_loss.device) # TODO: Implement
+                trust_region_loss = self.projection.get_trust_region_loss(
+                    pol, p, proj_p)
+                # NOTE to future-self: policy has a different interface than in orig TRL-impl.
 
                 trust_region_losses.append(trust_region_loss.item())
 
@@ -471,7 +473,7 @@ class TRL_PG(OnPolicyAlgorithm):
                         0]
                     rewards[idx] += self.gamma * terminal_value
 
-            # TODO: calc mean + std
+            # TODO: how to calc mean + std
             rollout_buffer.add(self._last_obs, actions, rewards,
                                self._last_episode_starts, values, log_probs, mean, std)
             self._last_obs = new_obs
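
Usage sketch (not part of the diff): with this change the projection is passed to the TRL_PG constructor instead of the trust-region loss being hard-coded to zero. The snippet below is a minimal, hypothetical example; the module paths follow the package layout implied by the diff, the environment id and timestep budget are placeholders, and the projection is passed as a class to match the new default (`projection: BaseProjectionLayer = BaseProjectionLayer`).

# Hypothetical usage sketch -- not part of this commit.
# Module paths assumed from the diff header and the relative imports above.
from sb3_trl.trl_pg.trl_pg import TRL_PG
from sb3_trl.projections.frob_projection_layer import FrobeniusProjectionLayer

# "MlpPolicy" resolves via policy_aliases to ActorCriticPolicy.
# "Pendulum-v1" is a placeholder environment id.
model = TRL_PG(
    "MlpPolicy",
    "Pendulum-v1",
    projection=FrobeniusProjectionLayer,
)
model.learn(total_timesteps=10_000)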
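
For context on the new `self.projection.get_trust_region_loss(pol, p, proj_p)` call: the projection layers themselves are not shown in this diff, so the function below is only an illustrative sketch of a Frobenius-style trust-region penalty between the current Gaussian policy parameters `p = (mean, std)` and their projected counterparts `proj_p`. The real `BaseProjectionLayer` / `FrobeniusProjectionLayer` interface and coefficient handling may differ.

# Illustrative sketch only -- not the repository's implementation.
import torch as th

def frobenius_trust_region_loss(p, proj_p, tr_coef=10.0):
    """Penalize squared deviation of the new policy's Gaussian parameters
    from their trust-region projection.

    p, proj_p: tuples of (mean, std) tensors with matching shapes.
    """
    mean, std = p
    proj_mean, proj_std = proj_p
    # Squared L2 distance of the means plus squared Frobenius distance of the
    # (diagonal) covariances; the projection is detached here so only the live
    # policy parameters receive gradients (a design choice of this sketch).
    mean_part = ((mean - proj_mean.detach()) ** 2).sum(dim=-1)
    cov_part = ((std ** 2 - proj_std.detach() ** 2) ** 2).sum(dim=-1)
    return tr_coef * (mean_part + cov_part).mean()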