From 74697e8773b26cb7864664bc87066838fa0533a6 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Fri, 15 Jul 2022 18:46:17 +0200 Subject: [PATCH] Smol bug fixes --- metastable_baselines/misc/distTools.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/metastable_baselines/misc/distTools.py b/metastable_baselines/misc/distTools.py index 2e2c0d0..e7ed925 100644 --- a/metastable_baselines/misc/distTools.py +++ b/metastable_baselines/misc/distTools.py @@ -1,3 +1,4 @@ +from sympy import numer import torch as th from stable_baselines3.common.distributions import Distribution as SB3_Distribution @@ -42,7 +43,7 @@ def get_cov(p: AnyDistribution): raise Exception('Dist-Type not implemented') -def has_diag_cov(p: AnyDistribution, numerical_check=True): +def has_diag_cov(p: AnyDistribution, numerical_check=False): if isinstance(p, SB3_Distribution): return has_diag_cov(p.distribution, numerical_check=numerical_check) if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent): @@ -51,7 +52,7 @@ def has_diag_cov(p: AnyDistribution, numerical_check=True): return False # Check if matrix is diag cov = get_cov(p) - return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov))) + return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1)), th.zeros_like(cov)) def is_contextual(p: AnyDistribution): @@ -59,8 +60,8 @@ def is_contextual(p: AnyDistribution): return False -def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True): - if check_diag and not has_diag_cov(p): +def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=False): + if check_diag and not has_diag_cov(p, numerical_check=numerical_check): raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal') return th.diagonal(get_cov(p), dim1=-2, dim2=-1)