Added support for parallel envs
This commit is contained in:
parent
5c39be5ead
commit
02e4ed1510
@ -22,13 +22,6 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution
|
||||
from ..misc.tensor_ops import fill_triangular
|
||||
from ..misc.tanhBijector import TanhBijector
|
||||
|
||||
# TODO: Integrate and Test what I currently have before adding more complexity
|
||||
# TODO: Support Squashed Dists (tanh)
|
||||
# TODO: Contextual Cov
|
||||
# TODO: - Hybrid
|
||||
# TODO: Contextual SDE (Scalar + Diag + Full)
|
||||
# TODO: (SqrtInducedCov (Scalar + Diag + Full))
|
||||
|
||||
|
||||
class Strength(Enum):
|
||||
NONE = 0
|
||||
@ -220,17 +213,13 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
return mean_actions, chol
|
||||
|
||||
def _sqrt_to_chol(self, cov_sqrt):
|
||||
vec = False
|
||||
nobatch = False
|
||||
if len(cov_sqrt.shape) <= 2:
|
||||
vec = True
|
||||
if len(cov_sqrt.shape) == 1:
|
||||
nobatch = True
|
||||
vec = self.cov_strength != Strength.FULL
|
||||
batch_dims = len(cov_sqrt.shape) - 2 + 1*vec
|
||||
|
||||
if vec:
|
||||
cov_sqrt = th.diag_embed(cov_sqrt)
|
||||
|
||||
if nobatch:
|
||||
if batch_dims == 0:
|
||||
cov = th.mm(cov_sqrt.mT, cov_sqrt)
|
||||
cov += th.eye(cov.shape[-1])*(self.epsilon)
|
||||
else:
|
||||
@ -533,8 +522,12 @@ class CholNet(nn.Module):
|
||||
# S[i,j] e (0, pi) where i = 2..n, j = 2..i
|
||||
# We already ensure S > 0 in _chol_from_flat_sphe_chol
|
||||
# We ensure < pi by applying tanh*pi to all applicable elements
|
||||
batch = (len(sphe_chol.shape) == 3)
|
||||
batch_size = sphe_chol.shape[0]
|
||||
vec = self.cov_strength != Strength.FULL
|
||||
batch_dims = len(sphe_chol.shape) - 2 + 1*vec
|
||||
batch = batch_dims != 0
|
||||
batch_shape = sphe_chol.shape[:batch_dims]
|
||||
batch_shape_scalar = batch_shape + (1,)
|
||||
|
||||
S = sphe_chol
|
||||
n = sphe_chol.shape[-1]
|
||||
L = th.zeros_like(sphe_chol)
|
||||
@ -542,13 +535,13 @@ class CholNet(nn.Module):
|
||||
#t = 1
|
||||
t = th.Tensor([1])[0]
|
||||
if batch:
|
||||
t = t.expand((batch_size, 1))
|
||||
t = t.expand(batch_shape_scalar)
|
||||
#s = ''
|
||||
for j in range(i+1):
|
||||
#maybe_cos = 1
|
||||
maybe_cos = th.Tensor([1])[0]
|
||||
if batch:
|
||||
maybe_cos = maybe_cos.expand((batch_size, 1))
|
||||
maybe_cos = maybe_cos.expand(batch_shape_scalar)
|
||||
#s_maybe_cos = ''
|
||||
if i != j and j < n-1 and i < n:
|
||||
if batch:
|
||||
|
@ -101,7 +101,7 @@ class ActorCriticPolicy(BasePolicy):
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
dist_kwargs: Optional[Dict[str, Any]] = None,
|
||||
sqrt_induced_gaussian=False,
|
||||
sqrt_induced_gaussian: bool = False,
|
||||
):
|
||||
|
||||
if optimizer_kwargs is None:
|
||||
|
Loading…
Reference in New Issue
Block a user