Smol bug fixes
This commit is contained in:
parent
2e0f46b0f3
commit
74697e8773
@ -1,3 +1,4 @@
|
|||||||
|
from sympy import numer
|
||||||
import torch as th
|
import torch as th
|
||||||
|
|
||||||
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
||||||
@ -42,7 +43,7 @@ def get_cov(p: AnyDistribution):
|
|||||||
raise Exception('Dist-Type not implemented')
|
raise Exception('Dist-Type not implemented')
|
||||||
|
|
||||||
|
|
||||||
def has_diag_cov(p: AnyDistribution, numerical_check=True):
|
def has_diag_cov(p: AnyDistribution, numerical_check=False):
|
||||||
if isinstance(p, SB3_Distribution):
|
if isinstance(p, SB3_Distribution):
|
||||||
return has_diag_cov(p.distribution, numerical_check=numerical_check)
|
return has_diag_cov(p.distribution, numerical_check=numerical_check)
|
||||||
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
|
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
|
||||||
@ -51,7 +52,7 @@ def has_diag_cov(p: AnyDistribution, numerical_check=True):
|
|||||||
return False
|
return False
|
||||||
# Check if matrix is diag
|
# Check if matrix is diag
|
||||||
cov = get_cov(p)
|
cov = get_cov(p)
|
||||||
return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov)))
|
return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1)), th.zeros_like(cov))
|
||||||
|
|
||||||
|
|
||||||
def is_contextual(p: AnyDistribution):
|
def is_contextual(p: AnyDistribution):
|
||||||
@ -59,8 +60,8 @@ def is_contextual(p: AnyDistribution):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True):
|
def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=False):
|
||||||
if check_diag and not has_diag_cov(p):
|
if check_diag and not has_diag_cov(p, numerical_check=numerical_check):
|
||||||
raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')
|
raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')
|
||||||
return th.diagonal(get_cov(p), dim1=-2, dim2=-1)
|
return th.diagonal(get_cov(p), dim1=-2, dim2=-1)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user