From c7ca326345897c42bb96f5e341df87b7ca46f4f6 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 3 Sep 2022 11:40:14 +0200 Subject: [PATCH] Fixing dependency problems --- metastable_projections/misc/distTools.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/metastable_projections/misc/distTools.py b/metastable_projections/misc/distTools.py index da8b6ce..67c21b5 100644 --- a/metastable_projections/misc/distTools.py +++ b/metastable_projections/misc/distTools.py @@ -1,8 +1,15 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + import torch as th from stable_baselines3.common.distributions import Distribution as SB3_Distribution -from ..distributions import UniversalGaussianDistribution, AnyDistribution + +class UniversalGaussianDistribution(SB3_Distribution): + pass + + +AnyDistribution = Union[SB3_Distribution, UniversalGaussianDistribution] def get_mean_and_chol(p: AnyDistribution, expand=False):