Added support for parallel envs

This commit is contained in:
Dominik Moritz Roth 2022-08-27 15:19:00 +02:00
parent 5c39be5ead
commit 02e4ed1510
2 changed files with 12 additions and 19 deletions

View File

@ -22,13 +22,6 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution
from ..misc.tensor_ops import fill_triangular
from ..misc.tanhBijector import TanhBijector
# TODO: Integrate and Test what I currently have before adding more complexity
# TODO: Support Squashed Dists (tanh)
# TODO: Contextual Cov
# TODO: - Hybrid
# TODO: Contextual SDE (Scalar + Diag + Full)
# TODO: (SqrtInducedCov (Scalar + Diag + Full))
class Strength(Enum):
NONE = 0
@ -220,17 +213,13 @@ class UniversalGaussianDistribution(SB3_Distribution):
return mean_actions, chol
def _sqrt_to_chol(self, cov_sqrt):
vec = False
nobatch = False
if len(cov_sqrt.shape) <= 2:
vec = True
if len(cov_sqrt.shape) == 1:
nobatch = True
vec = self.cov_strength != Strength.FULL
batch_dims = len(cov_sqrt.shape) - 2 + 1*vec
if vec:
cov_sqrt = th.diag_embed(cov_sqrt)
if nobatch:
if batch_dims == 0:
cov = th.mm(cov_sqrt.mT, cov_sqrt)
cov += th.eye(cov.shape[-1])*(self.epsilon)
else:
@ -533,8 +522,12 @@ class CholNet(nn.Module):
# S[i,j] e (0, pi) where i = 2..n, j = 2..i
# We already ensure S > 0 in _chol_from_flat_sphe_chol
# We ensure < pi by applying tanh*pi to all applicable elements
batch = (len(sphe_chol.shape) == 3)
batch_size = sphe_chol.shape[0]
vec = self.cov_strength != Strength.FULL
batch_dims = len(sphe_chol.shape) - 2 + 1*vec
batch = batch_dims != 0
batch_shape = sphe_chol.shape[:batch_dims]
batch_shape_scalar = batch_shape + (1,)
S = sphe_chol
n = sphe_chol.shape[-1]
L = th.zeros_like(sphe_chol)
@ -542,13 +535,13 @@ class CholNet(nn.Module):
#t = 1
t = th.Tensor([1])[0]
if batch:
t = t.expand((batch_size, 1))
t = t.expand(batch_shape_scalar)
#s = ''
for j in range(i+1):
#maybe_cos = 1
maybe_cos = th.Tensor([1])[0]
if batch:
maybe_cos = maybe_cos.expand((batch_size, 1))
maybe_cos = maybe_cos.expand(batch_shape_scalar)
#s_maybe_cos = ''
if i != j and j < n-1 and i < n:
if batch:

View File

@ -101,7 +101,7 @@ class ActorCriticPolicy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
dist_kwargs: Optional[Dict[str, Any]] = None,
sqrt_induced_gaussian=False,
sqrt_induced_gaussian: bool = False,
):
if optimizer_kwargs is None: