Using BaseProjectionLayer as Default

parent cafc90409f
commit b9f66dd95d
@@ -15,6 +15,9 @@ from stable_baselines3.common.buffers import RolloutBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.utils import obs_as_tensor
 
+from ..projections.base_projection_layer import BaseProjectionLayer
+from ..projections.frob_projection_layer import FrobeniusProjectionLayer
+
 
 class TRL_PG(OnPolicyAlgorithm):
     """
@@ -63,9 +66,9 @@ class TRL_PG(OnPolicyAlgorithm):
     :param seed: Seed for the pseudo random generators
     :param device: Device (cpu, cuda, ...) on which the code should be run.
         Setting it to auto, the code will be run on the GPU if possible.
+    :param projection: What kind of Projection to use
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
-    # TODO: Add new params to doc
 
     policy_aliases: Dict[str, Type[BasePolicy]] = {
         "MlpPolicy": ActorCriticPolicy,
@@ -100,8 +103,7 @@ class TRL_PG(OnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
 
         # Different from PPO:
-        #projection: BaseProjectionLayer = None,
-        projection=None,
+        projection: BaseProjectionLayer = BaseProjectionLayer,
 
         _init_setup_model: bool = True,
     ):
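With BaseProjectionLayer as the default (passed as a class, matching the convention in the hunk above), callers no longer need to supply a projection explicitly. A minimal usage sketch, assuming the PPO-style constructor arguments that TRL_PG inherits from OnPolicyAlgorithm and assuming import paths for the algorithm and projection layers (neither is shown in this diff):

    from trl_pg.trl_pg import TRL_PG                                    # assumed module path
    from trl_pg.projections.frob_projection_layer import FrobeniusProjectionLayer  # assumed module path

    # Default: no projection argument needed, BaseProjectionLayer is used.
    model = TRL_PG("MlpPolicy", "Pendulum-v1", verbose=1)

    # Opting into a specific projection; passed as a class, like the default above.
    model_frob = TRL_PG("MlpPolicy", "Pendulum-v1",
                        projection=FrobeniusProjectionLayer)

    model.learn(total_timesteps=10_000)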
@@ -300,9 +302,9 @@ class TRL_PG(OnPolicyAlgorithm):
                 entropy_losses.append(entropy_loss.item())
 
                 # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
-                # trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
-                trust_region_loss = th.zeros(
-                    1, device=entropy_loss.device)  # TODO: Implement
+                trust_region_loss = self.projection.get_trust_region_loss(
+                    pol, p, proj_p)
+                # NOTE to future-self: policy has a different interface than in orig TRL-impl.
 
                 trust_region_losses.append(trust_region_loss.item())
 
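The new call site assumes the projection object exposes get_trust_region_loss(policy, p, proj_p). The real classes live in ../projections and are not part of this diff; the following is only an illustrative sketch of that contract, under the assumption that p and proj_p are (mean, std) tuples of the unprojected and projected Gaussian policy parameters and that the base layer contributes no penalty:

    import torch as th


    class BaseProjectionLayer:
        """Illustrative default: identity projection, hence no trust-region penalty."""

        def get_trust_region_loss(self, policy, p, proj_p):
            mean, _ = p
            return th.zeros(1, device=mean.device)


    class FrobeniusProjectionLayer(BaseProjectionLayer):
        """Illustrative Frobenius-style penalty between p and proj_p."""

        def __init__(self, trust_region_coeff: float = 8.0):  # coefficient value is an assumption
            self.trust_region_coeff = trust_region_coeff

        def get_trust_region_loss(self, policy, p, proj_p):
            (mean, std), (proj_mean, proj_std) = p, proj_p
            # Regress the current parameters toward the (detached) projected ones.
            mean_part = ((mean - proj_mean.detach()) ** 2).sum(-1)
            cov_part = ((std - proj_std.detach()) ** 2).sum(-1)
            return self.trust_region_coeff * (mean_part + cov_part).mean()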
@@ -471,7 +473,7 @@ class TRL_PG(OnPolicyAlgorithm):
                         0]
                     rewards[idx] += self.gamma * terminal_value
 
-            # TODO: calc mean + std
+            # TODO: how to calc mean + std
             rollout_buffer.add(self._last_obs, actions, rewards,
                                self._last_episode_starts, values, log_probs, mean, std)
             self._last_obs = new_obs
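On the remaining TODO ("how to calc mean + std"): one plausible way to obtain them in collect_rollouts, assuming a diagonal Gaussian ActorCriticPolicy, is to read them off the policy's action distribution for the current observations. This sketch uses the standard stable_baselines3 policy API, is not part of the commit, and leaves open the question the TODO hints at (whether the buffer should store the old or the projected parameters):

    import torch as th
    from stable_baselines3.common.utils import obs_as_tensor


    def gaussian_params(algo, obs):
        """Return (mean, std) of the current diagonal Gaussian policy for `obs`.

        `algo` is assumed to be the TRL_PG instance; this helper is
        hypothetical and only illustrates the idea.
        """
        with th.no_grad():
            obs_tensor = obs_as_tensor(obs, algo.device)
            dist = algo.policy.get_distribution(obs_tensor)
            # For DiagGaussianDistribution, the wrapped torch.distributions.Normal
            # exposes the parameters directly.
            return dist.distribution.mean, dist.distribution.stddev

A call such as mean, std = gaussian_params(self, self._last_obs) just before rollout_buffer.add(...) would match the buffer signature shown above.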