Using BaseProjectionLayer as Default

Dominik Moritz Roth 2022-06-25 15:33:43 +02:00
parent cafc90409f
commit b9f66dd95d


@@ -15,6 +15,9 @@ from stable_baselines3.common.buffers import RolloutBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.utils import obs_as_tensor
+from ..projections.base_projection_layer import BaseProjectionLayer
+from ..projections.frob_projection_layer import FrobeniusProjectionLayer
 
 class TRL_PG(OnPolicyAlgorithm):
     """
@@ -63,9 +66,9 @@ class TRL_PG(OnPolicyAlgorithm):
     :param seed: Seed for the pseudo random generators
     :param device: Device (cpu, cuda, ...) on which the code should be run.
         Setting it to auto, the code will be run on the GPU if possible.
+    :param projection: What kind of Projection to use
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
-    # TODO: Add new params to doc
 
     policy_aliases: Dict[str, Type[BasePolicy]] = {
         "MlpPolicy": ActorCriticPolicy,
@@ -100,8 +103,7 @@ class TRL_PG(OnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         # Different from PPO:
-        #projection: BaseProjectionLayer = None,
-        projection=None,
+        projection: BaseProjectionLayer = BaseProjectionLayer,
         _init_setup_model: bool = True,
     ):
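
Editor's note: for context, a minimal usage sketch of the new default. Everything below is an assumption for illustration: the import paths, the environment id, and the reading that the projection argument takes the (uninstantiated) layer class, as the new default projection: BaseProjectionLayer = BaseProjectionLayer suggests.

    # Hypothetical usage sketch, not part of the commit; module paths are assumptions.
    from trl_pg import TRL_PG
    from projections.frob_projection_layer import FrobeniusProjectionLayer

    # Default: BaseProjectionLayer is used when no projection is passed.
    model = TRL_PG("MlpPolicy", "Pendulum-v1", verbose=1)
    model.learn(total_timesteps=10_000)

    # Mirroring the default's style, an alternative projection class can be passed explicitly.
    model_frob = TRL_PG("MlpPolicy", "Pendulum-v1", projection=FrobeniusProjectionLayer)
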
@@ -300,9 +302,9 @@ class TRL_PG(OnPolicyAlgorithm):
                 entropy_losses.append(entropy_loss.item())
 
                 # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
-                # trust_region_loss = self.projection.get_trust_region_loss()  # TODO: params
-                trust_region_loss = th.zeros(
-                    1, device=entropy_loss.device)  # TODO: Implement
+                trust_region_loss = self.projection.get_trust_region_loss(
+                    pol, p, proj_p)
+                # NOTE to future-self: policy has a different interface than in orig TRL-impl.
                 trust_region_losses.append(trust_region_loss.item())
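
Editor's note: the new call site implies that every projection layer exposes get_trust_region_loss(policy, p, proj_p). Below is a minimal sketch of that assumed interface; the (mean, std) argument layout and the regression-style penalty are assumptions for illustration, not the repository's actual base_projection_layer.py.

    # Sketch (assumption, not the actual implementation) of the interface the
    # call site above relies on.
    class BaseProjectionLayer:
        def project(self, policy, p, q):
            # Base/identity behaviour sketch: keep the new parameters unchanged.
            return p

        def get_trust_region_loss(self, policy, p, proj_p):
            # p and proj_p are assumed to be (mean, std) tuples of the unprojected
            # and the projected Gaussian policy; the penalty is a simple regression
            # loss between the two (zero when they are identical).
            (mean, std), (proj_mean, proj_std) = p, proj_p
            return (((mean - proj_mean) ** 2).sum(dim=-1).mean()
                    + ((std - proj_std) ** 2).sum(dim=-1).mean())

    if __name__ == "__main__":
        import torch as th
        layer = BaseProjectionLayer()
        params = (th.zeros(4, 2), th.ones(4, 2))
        print(layer.get_trust_region_loss(None, params, params))  # tensor(0.)
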
@@ -471,7 +473,7 @@ class TRL_PG(OnPolicyAlgorithm):
                        0]
                    rewards[idx] += self.gamma * terminal_value
 
-            # TODO: calc mean + std
+            # TODO: how to calc mean + std
             rollout_buffer.add(self._last_obs, actions, rewards,
                                self._last_episode_starts, values, log_probs, mean, std)
             self._last_obs = new_obs
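
Editor's note: the remaining TODO asks how to obtain mean and std for the extended rollout_buffer.add signature. One possible approach, assuming a Stable-Baselines3 ActorCriticPolicy with a diagonal-Gaussian action distribution; the helper name below is hypothetical and the attribute layout is an assumption.

    import torch as th

    def gaussian_policy_stats(policy, obs_tensor):
        # Return (mean, std) of the policy's action distribution for obs_tensor.
        # Assumes policy.get_distribution(...).distribution is a
        # torch.distributions.Normal exposing .mean and .stddev.
        with th.no_grad():
            dist = policy.get_distribution(obs_tensor).distribution
            return dist.mean.cpu().numpy(), dist.stddev.cpu().numpy()

    # Inside collect_rollouts this could feed the extended buffer signature, e.g.:
    #   mean, std = gaussian_policy_stats(self.policy, obs_as_tensor(self._last_obs, self.device))
    #   rollout_buffer.add(self._last_obs, actions, rewards,
    #                      self._last_episode_starts, values, log_probs, mean, std)
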