Better handling of diagonal-covariance as vector and matrix
This commit is contained in:
parent
bc61a6db32
commit
f4c87c9cdc
@ -3,9 +3,12 @@ import torch as th
|
|||||||
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
||||||
|
|
||||||
|
|
||||||
def get_mean_and_chol(p):
|
def get_mean_and_chol(p, expand=False):
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal):
|
||||||
return p.mean, p.stddev
|
if expand:
|
||||||
|
return p.mean, th.diag_embed(p.stddev)
|
||||||
|
else:
|
||||||
|
return p.mean, p.stddev
|
||||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||||
return p.mean, p.scale_tril
|
return p.mean, p.scale_tril
|
||||||
elif isinstance(p, SB3_Distribution):
|
elif isinstance(p, SB3_Distribution):
|
||||||
@ -16,7 +19,7 @@ def get_mean_and_chol(p):
|
|||||||
|
|
||||||
def get_cov(p):
|
def get_cov(p):
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal):
|
||||||
return th.diag(p.variance)
|
return th.diag_embed(p.variance)
|
||||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||||
return p.covariance_matrix
|
return p.covariance_matrix
|
||||||
elif isinstance(p, SB3_Distribution):
|
elif isinstance(p, SB3_Distribution):
|
||||||
@ -27,6 +30,8 @@ def get_cov(p):
|
|||||||
|
|
||||||
def new_dist_like(orig_p, mean, chol):
|
def new_dist_like(orig_p, mean, chol):
|
||||||
if isinstance(orig_p, th.distributions.Normal):
|
if isinstance(orig_p, th.distributions.Normal):
|
||||||
|
if orig_p.stddev.shape != chol.shape:
|
||||||
|
chol = th.diagonal(chol, dim1=1, dim2=2)
|
||||||
return th.distributions.Normal(mean, chol)
|
return th.distributions.Normal(mean, chol)
|
||||||
elif isinstance(orig_p, th.distributions.MultivariateNormal):
|
elif isinstance(orig_p, th.distributions.MultivariateNormal):
|
||||||
return th.distributions.MultivariateNormal(mean, scale_tril=chol)
|
return th.distributions.MultivariateNormal(mean, scale_tril=chol)
|
||||||
|
Loading…
Reference in New Issue
Block a user