diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index 6377eb9..5d03868 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -37,7 +37,7 @@ from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.preprocessing import get_action_dim -from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer +from metastable_projections.projections.w2_projection_layer import WassersteinProjectionLayer from ..distributions import UniversalGaussianDistribution, make_proba_distribution from ..misc.distTools import get_mean_and_chol diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index 07aca36..bfb7b3a 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -18,10 +18,10 @@ from stable_baselines3.common.vec_env import VecNormalize from ..misc.distTools import new_dist_like, new_dist_like_from_sqrt -from ..projections.base_projection_layer import BaseProjectionLayer -from ..projections.frob_projection_layer import FrobeniusProjectionLayer -from ..projections.w2_projection_layer import WassersteinProjectionLayer -from ..projections.kl_projection_layer import KLProjectionLayer +from metastable_projections.projections.base_projection_layer import BaseProjectionLayer +from metastable_projections.projections.frob_projection_layer import FrobeniusProjectionLayer +from metastable_projections.projections.w2_projection_layer import WassersteinProjectionLayer +from metastable_projections.projections.kl_projection_layer import KLProjectionLayer from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass diff --git a/metastable_baselines/sac/sac.py b/metastable_baselines/sac/sac.py index 6244ef1..6206dd3 100644 --- a/metastable_baselines/sac/sac.py +++ b/metastable_baselines/sac/sac.py @@ -16,10 +16,10 @@ from metastable_baselines.sac.policies import CnnPolicy, MlpPolicy, MultiInputPo from ..misc.distTools import new_dist_like -from ..projections.base_projection_layer import BaseProjectionLayer -from ..projections.frob_projection_layer import FrobeniusProjectionLayer -from ..projections.w2_projection_layer import WassersteinProjectionLayer -from ..projections.kl_projection_layer import KLProjectionLayer +from metastable_projections.projections.base_projection_layer import BaseProjectionLayer +from metastable_projections.projections.frob_projection_layer import FrobeniusProjectionLayer +from metastable_projections.projections.w2_projection_layer import WassersteinProjectionLayer +from metastable_projections.projections.kl_projection_layer import KLProjectionLayer from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass diff --git a/test.py b/test.py index 12e3936..174fad9 100755 --- a/test.py +++ b/test.py @@ -10,7 +10,7 @@ from metastable_baselines.ppo import PPO from metastable_baselines.sac import SAC from metastable_baselines.ppo.policies import MlpPolicy as MlpPolicyPPO from metastable_baselines.sac.policies import MlpPolicy as MlpPolicySAC -from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer +from metastable_projections.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer import columbus from metastable_baselines.distributions import Strength, ParametrizationType, EnforcePositiveType, ProbSquashingType