EnforcePositiveType makes no sense for Strength.NONE
This commit is contained in:
parent
9133ecd61b
commit
b7de99b1fc
@ -79,6 +79,9 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
|
||||
for cs in allowedCovStrength:
|
||||
if ps.value > cs.value:
|
||||
continue
|
||||
if cs == Strength.NONE:
|
||||
yield (ps, cs, EnforcePositiveType.NONE, ParametrizationType.NONE)
|
||||
else:
|
||||
for ept in allowedEPTs:
|
||||
if cs == Strength.FULL:
|
||||
for pt in allowedPTs:
|
||||
@ -147,6 +150,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
assert (parameterization_type != ParametrizationType.NONE) == (
|
||||
cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
||||
|
||||
if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE:
|
||||
raise Exception(
|
||||
'You need to specify an enforce_positive_type for spherical_cholesky')
|
||||
|
||||
def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
|
||||
p = self.distribution
|
||||
if isinstance(p, Independent):
|
||||
@ -223,9 +230,9 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
def mode(self) -> th.Tensor:
|
||||
return self.distribution.mean
|
||||
|
||||
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
|
||||
# Update the proba distribution
|
||||
self.proba_distribution(mean_actions, log_std)
|
||||
self.proba_distribution(mean_actions, log_std, latent_pi=latent_pi)
|
||||
return self.get_actions(deterministic=deterministic)
|
||||
|
||||
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
|
Loading…
Reference in New Issue
Block a user