First implementation of hybrid (between sde and rex)
This commit is contained in:
parent
ae9a95cbfc
commit
c62723bef6
@ -136,7 +136,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
:param action_dim: Dimension of the action space.
|
:param action_dim: Dimension of the action space.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-3, sde_learn_features=False, sde_latent_softmax=False):
|
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-3, sde_learn_features=False, sde_latent_softmax=False, use_hybrid=False, hybrid_rex_fac=0.5):
|
||||||
super(UniversalGaussianDistribution, self).__init__()
|
super(UniversalGaussianDistribution, self).__init__()
|
||||||
self.action_dim = action_dim
|
self.action_dim = action_dim
|
||||||
self.par_strength = cast_to_enum(neural_strength, Strength)
|
self.par_strength = cast_to_enum(neural_strength, Strength)
|
||||||
@ -154,9 +154,15 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
self.gaussian_actions = None
|
self.gaussian_actions = None
|
||||||
|
|
||||||
self.use_sde = use_sde
|
self.use_sde = use_sde
|
||||||
|
self.use_hybrid = use_hybrid
|
||||||
|
self.hybrid_rex_fac = hybrid_rex_fac
|
||||||
self.learn_features = sde_learn_features
|
self.learn_features = sde_learn_features
|
||||||
self.sde_latent_softmax = sde_latent_softmax
|
self.sde_latent_softmax = sde_latent_softmax
|
||||||
|
|
||||||
|
if self.use_hybrid and not self.use_sde:
|
||||||
|
print('[!] use_hybrid forces use_sde to be true')
|
||||||
|
self.use_sde = True
|
||||||
|
|
||||||
assert (self.par_type != ParametrizationType.NONE) == (
|
assert (self.par_type != ParametrizationType.NONE) == (
|
||||||
self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
||||||
|
|
||||||
@ -298,7 +304,9 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
return self.distribution.entropy()
|
return self.distribution.entropy()
|
||||||
|
|
||||||
def sample(self) -> th.Tensor:
|
def sample(self) -> th.Tensor:
|
||||||
if self.use_sde:
|
if self.use_hybrid:
|
||||||
|
return self._sample_hybrid()
|
||||||
|
elif self.use_sde:
|
||||||
return self._sample_sde()
|
return self._sample_sde()
|
||||||
else:
|
else:
|
||||||
return self._sample_normal()
|
return self._sample_normal()
|
||||||
@ -315,6 +323,11 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
self.gaussian_actions = actions
|
self.gaussian_actions = actions
|
||||||
return self.prob_squashing_type.apply(actions)
|
return self.prob_squashing_type.apply(actions)
|
||||||
|
|
||||||
|
def _sample_hybrid(self) -> th.Tensor:
    """Draw a hybrid sample by blending a normal sample with an SDE sample.

    The result is the convex combination
    ``normal * f + sde * (1 - f)`` with ``f = self.hybrid_rex_fac``,
    passed through ``self.prob_squashing_type.apply``.

    :return: The blended (and squashed) actions.
    """
    # The constructor stores the blend weight as ``hybrid_rex_fac``; the
    # previous ``hybrid_rex_factor`` lookup raised AttributeError.
    rex_weight = self.hybrid_rex_fac
    # Keep the original evaluation order: normal sample first, then SDE.
    normal_part = self._sample_normal() * rex_weight
    sde_part = self._sample_sde() * (1 - rex_weight)
    actions = normal_part + sde_part
    # NOTE(review): the sampling helpers appear to squash their own output
    # and to set ``self.gaussian_actions`` themselves -- confirm that
    # squashing the blend again here, and leaving ``gaussian_actions`` as
    # set by the last helper call, is intended.
    return self.prob_squashing_type.apply(actions)
|
||||||
|
|
||||||
def mode(self) -> th.Tensor:
|
def mode(self) -> th.Tensor:
|
||||||
mode = self.distribution.mean
|
mode = self.distribution.mean
|
||||||
self.gaussian_actions = mode
|
self.gaussian_actions = mode
|
||||||
|
Loading…
Reference in New Issue
Block a user