diff --git a/metastable_projections/misc/distTools.py b/metastable_projections/misc/distTools.py index da8b6ce..67c21b5 100644 --- a/metastable_projections/misc/distTools.py +++ b/metastable_projections/misc/distTools.py @@ -1,8 +1,15 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + import torch as th from stable_baselines3.common.distributions import Distribution as SB3_Distribution -from ..distributions import UniversalGaussianDistribution, AnyDistribution + +class UniversalGaussianDistribution(SB3_Distribution): + pass + + +AnyDistribution = Union[SB3_Distribution, UniversalGaussianDistribution] def get_mean_and_chol(p: AnyDistribution, expand=False):