From e8d423f91ff69bea51b890295006b4d846a0651e Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 29 Jun 2022 12:46:37 +0200 Subject: [PATCH] Testing the new WassersteinProjectionLayer --- sb3_trl/trl_pg/trl_pg.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py index d286896..6209877 100644 --- a/sb3_trl/trl_pg/trl_pg.py +++ b/sb3_trl/trl_pg/trl_pg.py @@ -18,6 +18,7 @@ from stable_baselines3.common.vec_env import VecNormalize from ..projections.base_projection_layer import BaseProjectionLayer from ..projections.frob_projection_layer import FrobeniusProjectionLayer +from ..projections.w2_projection_layer import WassersteinProjectionLayer from ..misc.rollout_buffer import GaussianRolloutBuffer, GaussianRolloutBufferSamples @@ -107,7 +108,8 @@ class TRL_PG(OnPolicyAlgorithm): device: Union[th.device, str] = "auto", # Different from PPO: - projection: BaseProjectionLayer = FrobeniusProjectionLayer(), + projection: BaseProjectionLayer = WassersteinProjectionLayer(), + #projection: BaseProjectionLayer = FrobeniusProjectionLayer(), #projection: BaseProjectionLayer = BaseProjectionLayer(), _init_setup_model: bool = True,