Factor Projections out into metastable-projections

This commit is contained in:
Dominik Moritz Roth 2022-09-03 11:37:41 +02:00
parent 0a037deccc
commit 4bb772a251
4 changed files with 10 additions and 10 deletions

View File

@ -37,7 +37,7 @@ from stable_baselines3.common.torch_layers import (
from stable_baselines3.common.preprocessing import get_action_dim
from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer
from metastable_projections.projections.w2_projection_layer import WassersteinProjectionLayer
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
from ..misc.distTools import get_mean_and_chol

View File

@ -18,10 +18,10 @@ from stable_baselines3.common.vec_env import VecNormalize
from ..misc.distTools import new_dist_like, new_dist_like_from_sqrt
from ..projections.base_projection_layer import BaseProjectionLayer
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
from ..projections.w2_projection_layer import WassersteinProjectionLayer
from ..projections.kl_projection_layer import KLProjectionLayer
from metastable_projections.projections.base_projection_layer import BaseProjectionLayer
from metastable_projections.projections.frob_projection_layer import FrobeniusProjectionLayer
from metastable_projections.projections.w2_projection_layer import WassersteinProjectionLayer
from metastable_projections.projections.kl_projection_layer import KLProjectionLayer
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass

View File

@ -16,10 +16,10 @@ from metastable_baselines.sac.policies import CnnPolicy, MlpPolicy, MultiInputPo
from ..misc.distTools import new_dist_like
from ..projections.base_projection_layer import BaseProjectionLayer
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
from ..projections.w2_projection_layer import WassersteinProjectionLayer
from ..projections.kl_projection_layer import KLProjectionLayer
from metastable_projections.projections.base_projection_layer import BaseProjectionLayer
from metastable_projections.projections.frob_projection_layer import FrobeniusProjectionLayer
from metastable_projections.projections.w2_projection_layer import WassersteinProjectionLayer
from metastable_projections.projections.kl_projection_layer import KLProjectionLayer
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass

View File

@ -10,7 +10,7 @@ from metastable_baselines.ppo import PPO
from metastable_baselines.sac import SAC
from metastable_baselines.ppo.policies import MlpPolicy as MlpPolicyPPO
from metastable_baselines.sac.policies import MlpPolicy as MlpPolicySAC
from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
from metastable_projections.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
import columbus
from metastable_baselines.distributions import Strength, ParametrizationType, EnforcePositiveType, ProbSquashingType