Implement support for VecEnvs
This commit is contained in:
parent
c68fd1635d
commit
b0e2bc3a7a
@ -89,6 +89,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
action_dim: int,
|
action_dim: int,
|
||||||
|
n_envs: int=1,
|
||||||
par_strength: Par_Strength = Par_Strength.CONT_DIAG,
|
par_strength: Par_Strength = Par_Strength.CONT_DIAG,
|
||||||
kernel_func=kernel.rbf(),
|
kernel_func=kernel.rbf(),
|
||||||
init_std: float = 1,
|
init_std: float = 1,
|
||||||
@ -96,6 +97,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
window: int = 64,
|
window: int = 64,
|
||||||
epsilon: float = 1e-6,
|
epsilon: float = 1e-6,
|
||||||
skip_conditioning: bool = False,
|
skip_conditioning: bool = False,
|
||||||
|
temporal_gradient_emission: bool = False,
|
||||||
Base_Noise=noise.White_Noise,
|
Base_Noise=noise.White_Noise,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -108,8 +110,9 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
self.window = window
|
self.window = window
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self.skip_conditioning = skip_conditioning
|
self.skip_conditioning = skip_conditioning
|
||||||
|
self.temporal_gradient_emission = temporal_gradient_emission
|
||||||
|
|
||||||
self.base_noise = cast_to_Noise(Base_Noise, (1, action_dim))
|
self.base_noise = cast_to_Noise(Base_Noise, (n_envs, action_dim))
|
||||||
|
|
||||||
# Premature optimization is the root of all evil
|
# Premature optimization is the root of all evil
|
||||||
self._build_conditioner()
|
self._build_conditioner()
|
||||||
@ -128,7 +131,16 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||||
return sum_independent_dims(self.distribution.log_prob(actions.to(self.distribution.mean.device)))
|
return self._log_prob(actions, self.distribution)
|
||||||
|
|
||||||
|
def conditioned_log_prob(self, actions: th.Tensor, trajectory: th.Tensor = None) -> th.Tensor:
|
||||||
|
pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
|
||||||
|
rho_mean, rho_std = self._conditioning_engine(trajectory, pi_mean, pi_std)
|
||||||
|
dist = Normal(rho_mean, rho_std)
|
||||||
|
return self._log_prob(dist)
|
||||||
|
|
||||||
|
def _log_prob(self, actions: th.Tensor, dist: Normal):
|
||||||
|
return sum_independent_dims(dist.log_prob(actions.to(dist.mean.device)))
|
||||||
|
|
||||||
def entropy(self) -> th.Tensor:
|
def entropy(self) -> th.Tensor:
|
||||||
return sum_independent_dims(self.distribution.entropy())
|
return sum_independent_dims(self.distribution.entropy())
|
||||||
@ -164,7 +176,13 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
|
return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
|
||||||
|
|
||||||
def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
|
def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
|
||||||
|
# Ugly function to ensure, that the gradients flow as intended for each modus operandi
|
||||||
|
if not self.temporal_gradient_emission or self.skip_conditioning:
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
|
return self._get_emitting_rigged(pi_mean, pi_std, rho_mean, rho_std, epsilon=epsilon).detach()
|
||||||
|
return self._get_emitting_rigged(pi_mean.detach(), pi_std.detach(), rho_mean, rho_std, epsilon=epsilon)
|
||||||
|
|
||||||
|
def _get_emitting_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
|
||||||
if epsilon == None:
|
if epsilon == None:
|
||||||
epsilon = self.base_noise(pi_mean.shape)
|
epsilon = self.base_noise(pi_mean.shape)
|
||||||
|
|
||||||
@ -177,7 +195,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
|
|
||||||
eta = Pi_mu * Delta + Pi_sigma * epsilon
|
eta = Pi_mu * Delta + Pi_sigma * epsilon
|
||||||
|
|
||||||
return eta.detach()
|
return eta
|
||||||
|
|
||||||
def _pad_and_cut_trajectory(self, traj, value=0):
|
def _pad_and_cut_trajectory(self, traj, value=0):
|
||||||
if traj.shape[-2] < self.window:
|
if traj.shape[-2] < self.window:
|
||||||
@ -194,18 +212,15 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
return pi_mean, pi_std
|
return pi_mean, pi_std
|
||||||
|
|
||||||
traj = self._pad_and_cut_trajectory(trajectory)
|
traj = self._pad_and_cut_trajectory(trajectory)
|
||||||
|
y = th.cat((traj.transpose(-1, -2), pi_mean.unsqueeze(-1).unsqueeze(0).repeat(traj.shape[0], 1, traj.shape[-2])), dim=1)
|
||||||
# Numpy is fun
|
|
||||||
y_np = np.append(np.swapaxes(traj, -1, -2), np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
|
|
||||||
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
conditioners = th.Tensor(self._adapt_conditioner(pi_std))
|
conditioners = th.Tensor(self._adapt_conditioner(pi_std))
|
||||||
y = th.Tensor(y_np)
|
|
||||||
|
|
||||||
S = th.cholesky_solve(self.Sig12.expand(conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
|
S = th.cholesky_solve(self.Sig12.expand(conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
|
||||||
|
|
||||||
rho_mean = th.einsum('bai,bai->ba', S, y)
|
|
||||||
rho_std = self.Sig22 - (S @ self.Sig12)
|
rho_std = self.Sig22 - (S @ self.Sig12)
|
||||||
|
rho_mean = th.einsum('bai,bai->ba', S, y)
|
||||||
|
|
||||||
return rho_mean, rho_std
|
return rho_mean, rho_std
|
||||||
|
|
||||||
@ -254,13 +269,13 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
return self.distribution.mean
|
return self.distribution.mean
|
||||||
|
|
||||||
def actions_from_params(
|
def actions_from_params(
|
||||||
self, mean: th.Tensor, std: th.Tensor, deterministic: bool = False
|
self, mean: th.Tensor, std: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None
|
||||||
) -> th.Tensor:
|
) -> th.Tensor:
|
||||||
self.proba_distribution(mean, std)
|
self.proba_distribution(mean, std)
|
||||||
return self.get_actions(deterministic=deterministic)
|
return self.get_actions(deterministic=deterministic, trajectory=trajectory)
|
||||||
|
|
||||||
def log_prob_from_params(self, mean: th.Tensor, std: th.Tensor):
|
def log_prob_from_params(self, mean: th.Tensor, std: th.Tensor, trajectory: th.Tensor = None):
|
||||||
actions = self.actions_from_params(mean, std)
|
actions = self.actions_from_params(mean, std, trajectory=trajectory)
|
||||||
log_prob = self.log_prob(actions)
|
log_prob = self.log_prob(actions)
|
||||||
return actions, log_prob
|
return actions, log_prob
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user