EnforcePositiveType makes no sense for Strength.NONE

This commit is contained in:
Dominik Moritz Roth 2022-07-19 10:06:40 +02:00
parent 9133ecd61b
commit b7de99b1fc

View File

@ -79,6 +79,9 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
for cs in allowedCovStrength:
if ps.value > cs.value:
continue
if cs == Strength.NONE:
yield (ps, cs, EnforcePositiveType.NONE, ParametrizationType.NONE)
else:
for ept in allowedEPTs:
if cs == Strength.FULL:
for pt in allowedPTs:
@ -147,6 +150,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
assert (parameterization_type != ParametrizationType.NONE) == (
cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE:
raise Exception(
'You need to specify an enforce_positive_type for spherical_cholesky')
def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
p = self.distribution
if isinstance(p, Independent):
@ -223,9 +230,9 @@ class UniversalGaussianDistribution(SB3_Distribution):
def mode(self) -> th.Tensor:
return self.distribution.mean
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(mean_actions, log_std)
self.proba_distribution(mean_actions, log_std, latent_pi=latent_pi)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: