Testing the new WassersteinProjectionLayer
This commit is contained in:
parent
4e77190d8e
commit
e8d423f91f
@ -18,6 +18,7 @@ from stable_baselines3.common.vec_env import VecNormalize
|
|||||||
|
|
||||||
from ..projections.base_projection_layer import BaseProjectionLayer
|
from ..projections.base_projection_layer import BaseProjectionLayer
|
||||||
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
|
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
|
||||||
|
from ..projections.w2_projection_layer import WassersteinProjectionLayer
|
||||||
|
|
||||||
from ..misc.rollout_buffer import GaussianRolloutBuffer, GaussianRolloutBufferSamples
|
from ..misc.rollout_buffer import GaussianRolloutBuffer, GaussianRolloutBufferSamples
|
||||||
|
|
||||||
@ -107,7 +108,8 @@ class TRL_PG(OnPolicyAlgorithm):
|
|||||||
device: Union[th.device, str] = "auto",
|
device: Union[th.device, str] = "auto",
|
||||||
|
|
||||||
# Different from PPO:
|
# Different from PPO:
|
||||||
projection: BaseProjectionLayer = FrobeniusProjectionLayer(),
|
projection: BaseProjectionLayer = WassersteinProjectionLayer(),
|
||||||
|
#projection: BaseProjectionLayer = FrobeniusProjectionLayer(),
|
||||||
#projection: BaseProjectionLayer = BaseProjectionLayer(),
|
#projection: BaseProjectionLayer = BaseProjectionLayer(),
|
||||||
|
|
||||||
_init_setup_model: bool = True,
|
_init_setup_model: bool = True,
|
||||||
|
Loading…
Reference in New Issue
Block a user