47 lines
1.7 KiB
Python
47 lines
1.7 KiB
Python
import torch as th
|
|
|
|
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
|
|
|
|
|
def get_mean_and_chol(p):
    """Return the mean and the scale factor of a Gaussian distribution.

    Args:
        p: A ``th.distributions.Normal``, a
            ``th.distributions.MultivariateNormal``, or an SB3 distribution
            whose ``distribution`` attribute holds one of those.

    Returns:
        A ``(mean, chol)`` tuple. For ``Normal`` the second element is
        ``p.stddev``; for ``MultivariateNormal`` it is ``p.scale_tril``.

    Raises:
        Exception: If ``p`` is none of the supported types.
    """
    if isinstance(p, th.distributions.Normal):
        loc = p.mean
        scale = p.stddev
        return loc, scale

    if isinstance(p, th.distributions.MultivariateNormal):
        loc = p.mean
        scale = p.scale_tril
        return loc, scale

    if not isinstance(p, SB3_Distribution):
        raise Exception('Dist-Type not implemented')

    # SB3 wrapper: delegate to the torch distribution it carries.
    return get_mean_and_chol(p.distribution)
|
|
|
|
|
|
def get_cov(p):
    """Return the covariance matrix of a Gaussian distribution.

    Args:
        p: A ``th.distributions.Normal``, a
            ``th.distributions.MultivariateNormal``, or an SB3 distribution
            whose ``distribution`` attribute holds one of those.

    Returns:
        For ``Normal``: a diagonal matrix built from ``p.variance``, with
        the last dimension of the variance expanded into a square matrix
        (so leading batch dimensions are preserved).
        For ``MultivariateNormal``: ``p.covariance_matrix``.

    Raises:
        Exception: If ``p`` is none of the supported types.
    """
    if isinstance(p, th.distributions.Normal):
        # diag_embed places the last dimension on the diagonal of a new
        # trailing (dim, dim) matrix. Unlike th.diag, it does not switch to
        # "extract the diagonal" mode when the variance is 2-D (batched),
        # and it gives the same result as th.diag for a 1-D variance.
        return th.diag_embed(p.variance)
    elif isinstance(p, th.distributions.MultivariateNormal):
        return p.covariance_matrix
    elif isinstance(p, SB3_Distribution):
        # SB3 wrapper: delegate to the torch distribution it carries.
        return get_cov(p.distribution)
    else:
        raise Exception('Dist-Type not implemented')
|
|
|
|
|
|
def new_dist_like(orig_p, mean, chol):
    """Build a distribution of the same kind as ``orig_p`` with new parameters.

    Args:
        orig_p: Template distribution: a ``th.distributions.Normal``, a
            ``th.distributions.MultivariateNormal``, or an SB3 distribution
            whose ``distribution`` attribute holds one of those.
        mean: Mean of the new distribution.
        chol: Scale of the new distribution; passed as the second argument
            of ``Normal`` or as ``scale_tril`` of ``MultivariateNormal``.

    Returns:
        A new torch distribution, or, for an SB3 template, a new instance of
        ``orig_p.__class__`` (constructed from ``orig_p.action_dim``) whose
        ``distribution`` attribute is the new torch distribution.

    Raises:
        Exception: If ``orig_p``, or the distribution wrapped by an SB3
            template, is of an unsupported type.
    """
    normal_cls = th.distributions.Normal
    mvn_cls = th.distributions.MultivariateNormal

    # Plain torch distributions are rebuilt directly.
    if isinstance(orig_p, normal_cls):
        return normal_cls(mean, chol)
    if isinstance(orig_p, mvn_cls):
        return mvn_cls(mean, scale_tril=chol)

    if not isinstance(orig_p, SB3_Distribution):
        raise Exception('Dist-Type not implemented')

    # SB3 wrapper: create a fresh wrapper of the same class and attach a
    # torch distribution matching the type of the wrapped one.
    wrapped = orig_p.distribution

    if isinstance(wrapped, normal_cls):
        clone = orig_p.__class__(orig_p.action_dim)
        clone.distribution = normal_cls(mean, chol)
        return clone

    if isinstance(wrapped, mvn_cls):
        clone = orig_p.__class__(orig_p.action_dim)
        clone.distribution = mvn_cls(mean, scale_tril=chol)
        return clone

    raise Exception('Dist-Type not implemented (of sb3 dist)')
|