Added support for parallel envs
This commit is contained in:
parent
5c39be5ead
commit
02e4ed1510
@ -22,13 +22,6 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution
|
|||||||
from ..misc.tensor_ops import fill_triangular
|
from ..misc.tensor_ops import fill_triangular
|
||||||
from ..misc.tanhBijector import TanhBijector
|
from ..misc.tanhBijector import TanhBijector
|
||||||
|
|
||||||
# TODO: Integrate and Test what I currently have before adding more complexity
|
|
||||||
# TODO: Support Squashed Dists (tanh)
|
|
||||||
# TODO: Contextual Cov
|
|
||||||
# TODO: - Hybrid
|
|
||||||
# TODO: Contextual SDE (Scalar + Diag + Full)
|
|
||||||
# TODO: (SqrtInducedCov (Scalar + Diag + Full))
|
|
||||||
|
|
||||||
|
|
||||||
class Strength(Enum):
|
class Strength(Enum):
|
||||||
NONE = 0
|
NONE = 0
|
||||||
@ -220,17 +213,13 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
return mean_actions, chol
|
return mean_actions, chol
|
||||||
|
|
||||||
def _sqrt_to_chol(self, cov_sqrt):
|
def _sqrt_to_chol(self, cov_sqrt):
|
||||||
vec = False
|
vec = self.cov_strength != Strength.FULL
|
||||||
nobatch = False
|
batch_dims = len(cov_sqrt.shape) - 2 + 1*vec
|
||||||
if len(cov_sqrt.shape) <= 2:
|
|
||||||
vec = True
|
|
||||||
if len(cov_sqrt.shape) == 1:
|
|
||||||
nobatch = True
|
|
||||||
|
|
||||||
if vec:
|
if vec:
|
||||||
cov_sqrt = th.diag_embed(cov_sqrt)
|
cov_sqrt = th.diag_embed(cov_sqrt)
|
||||||
|
|
||||||
if nobatch:
|
if batch_dims == 0:
|
||||||
cov = th.mm(cov_sqrt.mT, cov_sqrt)
|
cov = th.mm(cov_sqrt.mT, cov_sqrt)
|
||||||
cov += th.eye(cov.shape[-1])*(self.epsilon)
|
cov += th.eye(cov.shape[-1])*(self.epsilon)
|
||||||
else:
|
else:
|
||||||
@ -533,8 +522,12 @@ class CholNet(nn.Module):
|
|||||||
# S[i,j] in (0, pi) where i = 2..n, j = 2..i
|
# S[i,j] in (0, pi) where i = 2..n, j = 2..i
|
||||||
# We already ensure S > 0 in _chol_from_flat_sphe_chol
|
# We already ensure S > 0 in _chol_from_flat_sphe_chol
|
||||||
# We ensure S < pi by applying tanh*pi to all applicable elements
|
# We ensure S < pi by applying tanh*pi to all applicable elements
|
||||||
batch = (len(sphe_chol.shape) == 3)
|
vec = self.cov_strength != Strength.FULL
|
||||||
batch_size = sphe_chol.shape[0]
|
batch_dims = len(sphe_chol.shape) - 2 + 1*vec
|
||||||
|
batch = batch_dims != 0
|
||||||
|
batch_shape = sphe_chol.shape[:batch_dims]
|
||||||
|
batch_shape_scalar = batch_shape + (1,)
|
||||||
|
|
||||||
S = sphe_chol
|
S = sphe_chol
|
||||||
n = sphe_chol.shape[-1]
|
n = sphe_chol.shape[-1]
|
||||||
L = th.zeros_like(sphe_chol)
|
L = th.zeros_like(sphe_chol)
|
||||||
@ -542,13 +535,13 @@ class CholNet(nn.Module):
|
|||||||
#t = 1
|
#t = 1
|
||||||
t = th.Tensor([1])[0]
|
t = th.Tensor([1])[0]
|
||||||
if batch:
|
if batch:
|
||||||
t = t.expand((batch_size, 1))
|
t = t.expand(batch_shape_scalar)
|
||||||
#s = ''
|
#s = ''
|
||||||
for j in range(i+1):
|
for j in range(i+1):
|
||||||
#maybe_cos = 1
|
#maybe_cos = 1
|
||||||
maybe_cos = th.Tensor([1])[0]
|
maybe_cos = th.Tensor([1])[0]
|
||||||
if batch:
|
if batch:
|
||||||
maybe_cos = maybe_cos.expand((batch_size, 1))
|
maybe_cos = maybe_cos.expand(batch_shape_scalar)
|
||||||
#s_maybe_cos = ''
|
#s_maybe_cos = ''
|
||||||
if i != j and j < n-1 and i < n:
|
if i != j and j < n-1 and i < n:
|
||||||
if batch:
|
if batch:
|
||||||
|
@ -101,7 +101,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
dist_kwargs: Optional[Dict[str, Any]] = None,
|
dist_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
sqrt_induced_gaussian=False,
|
sqrt_induced_gaussian: bool = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
if optimizer_kwargs is None:
|
if optimizer_kwargs is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user