Using BaseProjectionLayer as Default

This commit is contained in:
Dominik Moritz Roth 2022-06-25 15:33:43 +02:00
parent cafc90409f
commit b9f66dd95d

View File

@ -15,6 +15,9 @@ from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import obs_as_tensor
from ..projections.base_projection_layer import BaseProjectionLayer
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
class TRL_PG(OnPolicyAlgorithm):
"""
@ -63,9 +66,9 @@ class TRL_PG(OnPolicyAlgorithm):
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param projection: Projection layer used to compute the trust region loss (defaults to ``BaseProjectionLayer``)
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
# TODO: Add new params to doc
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": ActorCriticPolicy,
@ -100,8 +103,7 @@ class TRL_PG(OnPolicyAlgorithm):
device: Union[th.device, str] = "auto",
# Different from PPO:
#projection: BaseProjectionLayer = None,
projection=None,
projection: BaseProjectionLayer = BaseProjectionLayer,
_init_setup_model: bool = True,
):
@ -300,9 +302,9 @@ class TRL_PG(OnPolicyAlgorithm):
entropy_losses.append(entropy_loss.item())
# Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
# trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
trust_region_loss = th.zeros(
1, device=entropy_loss.device) # TODO: Implement
trust_region_loss = self.projection.get_trust_region_loss(
pol, p, proj_p)
# NOTE to future self: the policy has a different interface than in the original TRL implementation.
trust_region_losses.append(trust_region_loss.item())
@ -471,7 +473,7 @@ class TRL_PG(OnPolicyAlgorithm):
0]
rewards[idx] += self.gamma * terminal_value
# TODO: calc mean + std
# TODO: how to calc mean + std
rollout_buffer.add(self._last_obs, actions, rewards,
self._last_episode_starts, values, log_probs, mean, std)
self._last_obs = new_obs