Using BaseProjectionLayer as Default
parent cafc90409f
commit b9f66dd95d
@@ -15,6 +15,9 @@ from stable_baselines3.common.buffers import RolloutBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.utils import obs_as_tensor
 
+from ..projections.base_projection_layer import BaseProjectionLayer
+from ..projections.frob_projection_layer import FrobeniusProjectionLayer
+
 
 class TRL_PG(OnPolicyAlgorithm):
     """
@@ -63,9 +66,9 @@ class TRL_PG(OnPolicyAlgorithm):
     :param seed: Seed for the pseudo random generators
     :param device: Device (cpu, cuda, ...) on which the code should be run.
         Setting it to auto, the code will be run on the GPU if possible.
     :param projection: What kind of Projection to use
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
     # TODO: Add new params to doc
 
     policy_aliases: Dict[str, Type[BasePolicy]] = {
         "MlpPolicy": ActorCriticPolicy,
@@ -100,8 +103,7 @@ class TRL_PG(OnPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
 
         # Different from PPO:
-        #projection: BaseProjectionLayer = None,
-        projection=None,
+        projection: BaseProjectionLayer = BaseProjectionLayer,
 
         _init_setup_model: bool = True,
     ):
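The hunk above replaces the untyped `projection=None` keyword with a typed default, `projection: BaseProjectionLayer = BaseProjectionLayer`; note that the default is the class object itself, not an instance. Below is a usage sketch only: the import paths, the environment, and the `(policy, env, ...)` call signature are assumptions based on the stable-baselines3 conventions this class follows, not taken from the repo.

# Usage sketch; `sb3_trl` is a placeholder package name.
import gym

from sb3_trl.trl_pg import TRL_PG                                               # placeholder path
from sb3_trl.projections.frob_projection_layer import FrobeniusProjectionLayer  # placeholder path

env = gym.make("Pendulum-v1")

# Relies on the new default: projection=BaseProjectionLayer
model = TRL_PG("MlpPolicy", env, device="auto")

# Overrides the default with the Frobenius projection imported in the first hunk
model_frob = TRL_PG("MlpPolicy", env, projection=FrobeniusProjectionLayer)

Whether TRL_PG instantiates the projection class internally or expects a ready-made layer is not visible in this hunk; the sketch only shows how the default and an override would be passed.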
@@ -300,9 +302,9 @@ class TRL_PG(OnPolicyAlgorithm):
                 entropy_losses.append(entropy_loss.item())
 
                 # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
-                # trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
-                trust_region_loss = th.zeros(
-                    1, device=entropy_loss.device)  # TODO: Implement
+                trust_region_loss = self.projection.get_trust_region_loss(
+                    pol, p, proj_p)
+                # NOTE to future-self: policy has a different interface then in orig TRL-impl.
 
                 trust_region_losses.append(trust_region_loss.item())
 
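This hunk swaps the zero-tensor placeholder for a real call to `self.projection.get_trust_region_loss(pol, p, proj_p)`. The sketch below is only what that call site implies about the interface: the argument names come from the diff, but their types, the return shape, and the zero-loss behaviour of the base layer are assumptions, not the repo's actual class.

# Interface sketch inferred from the call site; not the real BaseProjectionLayer.
import torch as th


class ZeroTrustRegionLoss:
    """Stand-in with the same call signature as the projection layers used above."""

    def get_trust_region_loss(self, policy, p, proj_p) -> th.Tensor:
        # A no-op projection contributes nothing, which is presumably why
        # BaseProjectionLayer can serve as a PPO-like default.
        return th.zeros(1)


# The training loop treats the result as one more scalar loss term
# (per the hunk's comment, policy_loss includes entropy_loss + trust_region_loss).
projection = ZeroTrustRegionLoss()
trust_region_loss = projection.get_trust_region_loss(None, None, None)
print(trust_region_loss.item())  # 0.0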
@@ -471,7 +473,7 @@ class TRL_PG(OnPolicyAlgorithm):
                         0]
                     rewards[idx] += self.gamma * terminal_value
 
-            # TODO: calc mean + std
+            # TODO: how to calc mean + std
             rollout_buffer.add(self._last_obs, actions, rewards,
                                self._last_episode_starts, values, log_probs, mean, std)
             self._last_obs = new_obs
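The `rollout_buffer.add(...)` call above passes `mean` and `std` in addition to the six arguments that the stock stable-baselines3 `RolloutBuffer.add()` accepts, so the repo presumably ships an extended buffer. A rough sketch of what such a buffer could look like follows; the class name, array shapes, and the assumption that `mean`/`std` arrive as numpy-convertible per-env arrays are guesses, not taken from the repo.

# Sketch of a buffer that also records the Gaussian policy parameters per step.
import numpy as np
from stable_baselines3.common.buffers import RolloutBuffer


class GaussianRolloutBuffer(RolloutBuffer):  # hypothetical name
    def reset(self) -> None:
        super().reset()
        # Extra per-step storage for the policy distribution parameters.
        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.stds = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)

    def add(self, obs, action, reward, episode_start, value, log_prob, mean, std) -> None:
        # Record the distribution parameters before the base class advances self.pos.
        self.means[self.pos] = np.array(mean)
        self.stds[self.pos] = np.array(std)
        super().add(obs, action, reward, episode_start, value, log_prob)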