From f4c87c9cdc72fd637dfbe81452d414ae38611ee3 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 26 Jun 2022 18:14:12 +0200 Subject: [PATCH] Better handling of diagonal-covariance as vector and matrix --- sb3_trl/misc/distTools.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sb3_trl/misc/distTools.py b/sb3_trl/misc/distTools.py index 0526850..e7ee3e0 100644 --- a/sb3_trl/misc/distTools.py +++ b/sb3_trl/misc/distTools.py @@ -3,9 +3,12 @@ import torch as th from stable_baselines3.common.distributions import Distribution as SB3_Distribution -def get_mean_and_chol(p): +def get_mean_and_chol(p, expand=False): if isinstance(p, th.distributions.Normal): - return p.mean, p.stddev + if expand: + return p.mean, th.diag_embed(p.stddev) + else: + return p.mean, p.stddev elif isinstance(p, th.distributions.MultivariateNormal): return p.mean, p.scale_tril elif isinstance(p, SB3_Distribution): @@ -16,7 +19,7 @@ def get_mean_and_chol(p): def get_cov(p): if isinstance(p, th.distributions.Normal): - return th.diag(p.variance) + return th.diag_embed(p.variance) elif isinstance(p, th.distributions.MultivariateNormal): return p.covariance_matrix elif isinstance(p, SB3_Distribution): @@ -27,6 +30,8 @@ def get_cov(p): def new_dist_like(orig_p, mean, chol): if isinstance(orig_p, th.distributions.Normal): + if orig_p.stddev.shape != chol.shape: + chol = th.diagonal(chol, dim1=1, dim2=2) return th.distributions.Normal(mean, chol) elif isinstance(orig_p, th.distributions.MultivariateNormal): return th.distributions.MultivariateNormal(mean, scale_tril=chol)