diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 82c1dbc..4f0373f 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -389,7 +389,7 @@ class CholNet(nn.Module): # We already ensure S > 0 in _chol_from_flat_sphe_chol # We ensure < pi by applying tanh*pi to all applicable elements S = sphe_chol - n = self.action_dim + n = sphe_chol.shape[-1] L = th.zeros_like(sphe_chol) for i in range(n): t = 1