EnforcePositiveType makes no sense for Strength.NONE
This commit is contained in:
parent
9133ecd61b
commit
b7de99b1fc
@ -79,13 +79,16 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
|
|||||||
for cs in allowedCovStrength:
|
for cs in allowedCovStrength:
|
||||||
if ps.value > cs.value:
|
if ps.value > cs.value:
|
||||||
continue
|
continue
|
||||||
for ept in allowedEPTs:
|
if cs == Strength.NONE:
|
||||||
if cs == Strength.FULL:
|
yield (ps, cs, EnforcePositiveType.NONE, ParametrizationType.NONE)
|
||||||
for pt in allowedPTs:
|
else:
|
||||||
if pt != ParametrizationType.NONE:
|
for ept in allowedEPTs:
|
||||||
yield (ps, cs, ept, pt)
|
if cs == Strength.FULL:
|
||||||
else:
|
for pt in allowedPTs:
|
||||||
yield (ps, cs, ept, ParametrizationType.NONE)
|
if pt != ParametrizationType.NONE:
|
||||||
|
yield (ps, cs, ept, pt)
|
||||||
|
else:
|
||||||
|
yield (ps, cs, ept, ParametrizationType.NONE)
|
||||||
|
|
||||||
|
|
||||||
def make_proba_distribution(
|
def make_proba_distribution(
|
||||||
@ -147,6 +150,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
assert (parameterization_type != ParametrizationType.NONE) == (
|
assert (parameterization_type != ParametrizationType.NONE) == (
|
||||||
cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
||||||
|
|
||||||
|
if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE:
|
||||||
|
raise Exception(
|
||||||
|
'You need to specify an enforce_positive_type for spherical_cholesky')
|
||||||
|
|
||||||
def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
|
def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
|
||||||
p = self.distribution
|
p = self.distribution
|
||||||
if isinstance(p, Independent):
|
if isinstance(p, Independent):
|
||||||
@ -223,9 +230,9 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
def mode(self) -> th.Tensor:
|
def mode(self) -> th.Tensor:
|
||||||
return self.distribution.mean
|
return self.distribution.mean
|
||||||
|
|
||||||
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
|
||||||
# Update the proba distribution
|
# Update the proba distribution
|
||||||
self.proba_distribution(mean_actions, log_std)
|
self.proba_distribution(mean_actions, log_std, latent_pi=latent_pi)
|
||||||
return self.get_actions(deterministic=deterministic)
|
return self.get_actions(deterministic=deterministic)
|
||||||
|
|
||||||
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||||
|
Loading…
Reference in New Issue
Block a user