EnforcePositiveType makes no sense for Strength.NONE

This commit is contained in:
Dominik Moritz Roth 2022-07-19 10:06:40 +02:00
parent 9133ecd61b
commit b7de99b1fc

View File

@ -79,13 +79,16 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
for cs in allowedCovStrength: for cs in allowedCovStrength:
if ps.value > cs.value: if ps.value > cs.value:
continue continue
for ept in allowedEPTs: if cs == Strength.NONE:
if cs == Strength.FULL: yield (ps, cs, EnforcePositiveType.NONE, ParametrizationType.NONE)
for pt in allowedPTs: else:
if pt != ParametrizationType.NONE: for ept in allowedEPTs:
yield (ps, cs, ept, pt) if cs == Strength.FULL:
else: for pt in allowedPTs:
yield (ps, cs, ept, ParametrizationType.NONE) if pt != ParametrizationType.NONE:
yield (ps, cs, ept, pt)
else:
yield (ps, cs, ept, ParametrizationType.NONE)
def make_proba_distribution( def make_proba_distribution(
@ -147,6 +150,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
assert (parameterization_type != ParametrizationType.NONE) == ( assert (parameterization_type != ParametrizationType.NONE) == (
cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE:
raise Exception(
'You need to specify an enforce_positive_type for spherical_cholesky')
def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor): def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
p = self.distribution p = self.distribution
if isinstance(p, Independent): if isinstance(p, Independent):
@ -223,9 +230,9 @@ class UniversalGaussianDistribution(SB3_Distribution):
def mode(self) -> th.Tensor: def mode(self) -> th.Tensor:
return self.distribution.mean return self.distribution.mean
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor: def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
# Update the proba distribution # Update the proba distribution
self.proba_distribution(mean_actions, log_std) self.proba_distribution(mean_actions, log_std, latent_pi=latent_pi)
return self.get_actions(deterministic=deterministic) return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: