diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py index d286896..6209877 100644 --- a/sb3_trl/trl_pg/trl_pg.py +++ b/sb3_trl/trl_pg/trl_pg.py @@ -18,6 +18,7 @@ from stable_baselines3.common.vec_env import VecNormalize from ..projections.base_projection_layer import BaseProjectionLayer from ..projections.frob_projection_layer import FrobeniusProjectionLayer +from ..projections.w2_projection_layer import WassersteinProjectionLayer from ..misc.rollout_buffer import GaussianRolloutBuffer, GaussianRolloutBufferSamples @@ -107,7 +108,8 @@ class TRL_PG(OnPolicyAlgorithm): device: Union[th.device, str] = "auto", # Different from PPO: - projection: BaseProjectionLayer = FrobeniusProjectionLayer(), + projection: BaseProjectionLayer = WassersteinProjectionLayer(), + #projection: BaseProjectionLayer = FrobeniusProjectionLayer(), #projection: BaseProjectionLayer = BaseProjectionLayer(), _init_setup_model: bool = True,