diff --git a/metastable_baselines2/common/policies.py b/metastable_baselines2/common/policies.py index 9190d3f..40db45d 100644 --- a/metastable_baselines2/common/policies.py +++ b/metastable_baselines2/common/policies.py @@ -37,7 +37,7 @@ from stable_baselines3.common.policies import ContinuousCritic from stable_baselines3.common.type_aliases import Schedule, RolloutBufferSamples from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor -from metastable_projections.projections import BaseProjectionLayer, IdentityProjectionLayer, FrobeniusProjectionLay, WassersteinProjectionLayer, KLProjectionLayer +from metastable_projections.projections import BaseProjectionLayer, IdentityProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer from .distributions import make_proba_distribution