Fixed UniversalGaussianDistribution losing its SDE settings when cloning
This commit is contained in:
parent
0ee65e789b
commit
bb1f9ecf2b
@ -143,7 +143,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
:param action_dim: Dimension of the action space.
|
:param action_dim: Dimension of the action space.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6, sde_learn_features=False, full_sde=None):
|
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6, sde_learn_features=False):
|
||||||
super(UniversalGaussianDistribution, self).__init__()
|
super(UniversalGaussianDistribution, self).__init__()
|
||||||
self.action_dim = action_dim
|
self.action_dim = action_dim
|
||||||
self.par_strength = cast_to_enum(neural_strength, Strength)
|
self.par_strength = cast_to_enum(neural_strength, Strength)
|
||||||
@ -163,8 +163,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
self.use_sde = use_sde
|
self.use_sde = use_sde
|
||||||
self.learn_features = sde_learn_features
|
self.learn_features = sde_learn_features
|
||||||
|
|
||||||
if full_sde != None:
|
print('sde', self.use_sde)
|
||||||
print('[!] Argument full_sde is only provided to remain compatible with vanilla SB3 PPO. It does not serve any function!')
|
|
||||||
|
|
||||||
assert (self.par_type != ParametrizationType.NONE) == (
|
assert (self.par_type != ParametrizationType.NONE) == (
|
||||||
self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
||||||
@ -181,8 +180,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
np = Independent(Normal(mean, chol), 1)
|
np = Independent(Normal(mean, chol), 1)
|
||||||
elif isinstance(p, MultivariateNormal):
|
elif isinstance(p, MultivariateNormal):
|
||||||
np = MultivariateNormal(mean, scale_tril=chol)
|
np = MultivariateNormal(mean, scale_tril=chol)
|
||||||
new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength,
|
new = UniversalGaussianDistribution(self.action_dim, use_sde=self.use_sde, neural_strength=self.par_strength, cov_strength=self.cov_strength,
|
||||||
parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
|
parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type, epsilon=self.epsilon, sde_learn_features=self.learn_features)
|
||||||
new.distribution = np
|
new.distribution = np
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
Loading…
Reference in New Issue
Block a user