Factor Projections out into metastable-projections
This commit is contained in:
parent
0a037deccc
commit
4bb772a251
@ -37,7 +37,7 @@ from stable_baselines3.common.torch_layers import (
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
|
||||
from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer
|
||||
from metastable_projections.projections.w2_projection_layer import WassersteinProjectionLayer
|
||||
|
||||
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
|
||||
from ..misc.distTools import get_mean_and_chol
|
||||
|
@ -18,10 +18,10 @@ from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
||||
from ..misc.distTools import new_dist_like, new_dist_like_from_sqrt
|
||||
|
||||
from ..projections.base_projection_layer import BaseProjectionLayer
|
||||
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
|
||||
from ..projections.w2_projection_layer import WassersteinProjectionLayer
|
||||
from ..projections.kl_projection_layer import KLProjectionLayer
|
||||
from metastable_projections.projections.base_projection_layer import BaseProjectionLayer
|
||||
from metastable_projections.projections.frob_projection_layer import FrobeniusProjectionLayer
|
||||
from metastable_projections.projections.w2_projection_layer import WassersteinProjectionLayer
|
||||
from metastable_projections.projections.kl_projection_layer import KLProjectionLayer
|
||||
|
||||
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
|
||||
|
||||
|
@ -16,10 +16,10 @@ from metastable_baselines.sac.policies import CnnPolicy, MlpPolicy, MultiInputPo
|
||||
|
||||
from ..misc.distTools import new_dist_like
|
||||
|
||||
from ..projections.base_projection_layer import BaseProjectionLayer
|
||||
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
|
||||
from ..projections.w2_projection_layer import WassersteinProjectionLayer
|
||||
from ..projections.kl_projection_layer import KLProjectionLayer
|
||||
from metastable_projections.projections.base_projection_layer import BaseProjectionLayer
|
||||
from metastable_projections.projections.frob_projection_layer import FrobeniusProjectionLayer
|
||||
from metastable_projections.projections.w2_projection_layer import WassersteinProjectionLayer
|
||||
from metastable_projections.projections.kl_projection_layer import KLProjectionLayer
|
||||
|
||||
|
||||
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
|
||||
|
2
test.py
2
test.py
@ -10,7 +10,7 @@ from metastable_baselines.ppo import PPO
|
||||
from metastable_baselines.sac import SAC
|
||||
from metastable_baselines.ppo.policies import MlpPolicy as MlpPolicyPPO
|
||||
from metastable_baselines.sac.policies import MlpPolicy as MlpPolicySAC
|
||||
from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
|
||||
from metastable_projections.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
|
||||
import columbus
|
||||
|
||||
from metastable_baselines.distributions import Strength, ParametrizationType, EnforcePositiveType, ProbSquashingType
|
||||
|
Loading…
Reference in New Issue
Block a user