Fixing dependency problems

This commit is contained in:
Dominik Moritz Roth 2022-09-03 11:40:14 +02:00
parent 6c7fc37116
commit c7ca326345

View File

@ -1,8 +1,15 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch as th import torch as th
from stable_baselines3.common.distributions import Distribution as SB3_Distribution from stable_baselines3.common.distributions import Distribution as SB3_Distribution
from ..distributions import UniversalGaussianDistribution, AnyDistribution
class UniversalGaussianDistribution(SB3_Distribution):
pass
AnyDistribution = Union[SB3_Distribution, UniversalGaussianDistribution]
def get_mean_and_chol(p: AnyDistribution, expand=False): def get_mean_and_chol(p: AnyDistribution, expand=False):