Testing the new WassersteinProjectionLayer
This commit is contained in:
parent
4e77190d8e
commit
e8d423f91f
@ -18,6 +18,7 @@ from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
||||
from ..projections.base_projection_layer import BaseProjectionLayer
|
||||
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
|
||||
from ..projections.w2_projection_layer import WassersteinProjectionLayer
|
||||
|
||||
from ..misc.rollout_buffer import GaussianRolloutBuffer, GaussianRolloutBufferSamples
|
||||
|
||||
@ -107,7 +108,8 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
device: Union[th.device, str] = "auto",
|
||||
|
||||
# Different from PPO:
|
||||
projection: BaseProjectionLayer = FrobeniusProjectionLayer(),
|
||||
projection: BaseProjectionLayer = WassersteinProjectionLayer(),
|
||||
#projection: BaseProjectionLayer = FrobeniusProjectionLayer(),
|
||||
#projection: BaseProjectionLayer = BaseProjectionLayer(),
|
||||
|
||||
_init_setup_model: bool = True,
|
||||
|
Loading…
Reference in New Issue
Block a user