diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py
index d656542..22dc522 100644
--- a/metastable_baselines/distributions/distributions.py
+++ b/metastable_baselines/distributions/distributions.py
@@ -136,7 +136,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
     :param action_dim: Dimension of the action space.
     """
 
-    def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-3, sde_learn_features=False, sde_latent_softmax=False):
+    def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-3, sde_learn_features=False, sde_latent_softmax=False, use_hybrid=False, hybrid_rex_fac=0.5):
         super(UniversalGaussianDistribution, self).__init__()
         self.action_dim = action_dim
         self.par_strength = cast_to_enum(neural_strength, Strength)
@@ -154,9 +154,15 @@ class UniversalGaussianDistribution(SB3_Distribution):
         self.gaussian_actions = None
 
         self.use_sde = use_sde
+        self.use_hybrid = use_hybrid
+        self.hybrid_rex_fac = hybrid_rex_fac
         self.learn_features = sde_learn_features
         self.sde_latent_softmax = sde_latent_softmax
 
+        if self.use_hybrid and not self.use_sde:
+            print('[!] use_hybrid forces use_sde to be true')
+            self.use_sde = True
+
         assert (self.par_type != ParametrizationType.NONE) == (
             self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
 
@@ -298,7 +304,9 @@ class UniversalGaussianDistribution(SB3_Distribution):
         return self.distribution.entropy()
 
     def sample(self) -> th.Tensor:
-        if self.use_sde:
+        if self.use_hybrid:
+            return self._sample_hybrid()
+        elif self.use_sde:
             return self._sample_sde()
         else:
             return self._sample_normal()
@@ -315,6 +323,19 @@ class UniversalGaussianDistribution(SB3_Distribution):
         self.gaussian_actions = actions
         return self.prob_squashing_type.apply(actions)
 
+    def _sample_hybrid(self) -> th.Tensor:
+        """
+        Sample an action as a weighted blend of a normal and an SDE sample.
+
+        :return: ``f * _sample_normal() + (1 - f) * _sample_sde()`` with
+            ``f = self.hybrid_rex_fac``, passed through the prob-squashing.
+        """
+        # NOTE(review): the helpers presumably squash and set
+        # self.gaussian_actions themselves; confirm re-squashing is intended.
+        f = self.hybrid_rex_fac
+        actions = self._sample_normal()*f + self._sample_sde()*(1-f)
+        return self.prob_squashing_type.apply(actions)
+
     def mode(self) -> th.Tensor:
         mode = self.distribution.mean
         self.gaussian_actions = mode