Testing the new WassersteinProjectionLayer

This commit is contained in:
Dominik Moritz Roth 2022-06-29 12:46:37 +02:00
parent 4e77190d8e
commit e8d423f91f

View File

@ -18,6 +18,7 @@ from stable_baselines3.common.vec_env import VecNormalize
from ..projections.base_projection_layer import BaseProjectionLayer
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
from ..projections.w2_projection_layer import WassersteinProjectionLayer
from ..misc.rollout_buffer import GaussianRolloutBuffer, GaussianRolloutBufferSamples
@ -107,7 +108,8 @@ class TRL_PG(OnPolicyAlgorithm):
device: Union[th.device, str] = "auto",
# Different from PPO:
projection: BaseProjectionLayer = FrobeniusProjectionLayer(),
projection: BaseProjectionLayer = WassersteinProjectionLayer(),
#projection: BaseProjectionLayer = FrobeniusProjectionLayer(),
#projection: BaseProjectionLayer = BaseProjectionLayer(),
_init_setup_model: bool = True,