Implement support for VecEnvs

parent c68fd1635d
commit b0e2bc3a7a
@@ -89,6 +89,7 @@ class PCA_Distribution(SB3_Distribution):
     def __init__(
         self,
         action_dim: int,
+        n_envs: int = 1,
         par_strength: Par_Strength = Par_Strength.CONT_DIAG,
         kernel_func=kernel.rbf(),
         init_std: float = 1,
@@ -96,6 +97,7 @@ class PCA_Distribution(SB3_Distribution):
         window: int = 64,
         epsilon: float = 1e-6,
         skip_conditioning: bool = False,
+        temporal_gradient_emission: bool = False,
         Base_Noise=noise.White_Noise,
     ):
         super().__init__()
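For context, a minimal construction sketch showing how the new n_envs argument lines up with a Stable-Baselines3 VecEnv. The import path for PCA_Distribution and the choice of environment are assumptions; only the n_envs keyword comes from this commit.

    from stable_baselines3.common.env_util import make_vec_env
    from priorConditionedAnnealing import PCA_Distribution  # assumed import path

    venv = make_vec_env("Pendulum-v1", n_envs=4)      # 4 parallel environments
    dist = PCA_Distribution(
        action_dim=venv.action_space.shape[0],
        n_envs=venv.num_envs,                         # must match the VecEnv width
    )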
@@ -108,8 +110,9 @@ class PCA_Distribution(SB3_Distribution):
         self.window = window
         self.epsilon = epsilon
         self.skip_conditioning = skip_conditioning
+        self.temporal_gradient_emission = temporal_gradient_emission
 
-        self.base_noise = cast_to_Noise(Base_Noise, (1, action_dim))
+        self.base_noise = cast_to_Noise(Base_Noise, (n_envs, action_dim))
 
         # Premature optimization is the root of all evil
         self._build_conditioner()
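Sizing the base noise as (n_envs, action_dim) gives every parallel environment its own noise stream. The interface of noise.White_Noise is only inferred here from cast_to_Noise(Base_Noise, shape) and the later call self.base_noise(pi_mean.shape); the class below is an illustrative stand-in, not the library's implementation.

    import torch as th

    class WhiteNoiseSketch:
        # Illustrative stand-in for noise.White_Noise (assumed interface).
        def __init__(self, known_shape):
            self.known_shape = known_shape          # (n_envs, action_dim)

        def __call__(self, shape=None):
            # i.i.d. standard-normal noise, one independent row per parallel env
            return th.randn(shape if shape is not None else self.known_shape)

    base_noise = WhiteNoiseSketch((4, 2))           # 4 envs, 2-dim action space
    eps = base_noise((4, 2))                        # tensor of shape (4, 2)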
@@ -128,7 +131,16 @@ class PCA_Distribution(SB3_Distribution):
         return self
 
     def log_prob(self, actions: th.Tensor) -> th.Tensor:
-        return sum_independent_dims(self.distribution.log_prob(actions.to(self.distribution.mean.device)))
+        return self._log_prob(actions, self.distribution)
+
+    def conditioned_log_prob(self, actions: th.Tensor, trajectory: th.Tensor = None) -> th.Tensor:
+        pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
+        rho_mean, rho_std = self._conditioning_engine(trajectory, pi_mean, pi_std)
+        dist = Normal(rho_mean, rho_std)
+        return self._log_prob(actions, dist)
+
+    def _log_prob(self, actions: th.Tensor, dist: Normal):
+        return sum_independent_dims(dist.log_prob(actions.to(dist.mean.device)))
 
     def entropy(self) -> th.Tensor:
         return sum_independent_dims(self.distribution.entropy())
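The refactor routes log_prob and the new conditioned_log_prob through a shared _log_prob helper that just sums per-dimension Normal log-densities. A standalone check of that behaviour, assuming batched tensors of shape (n_envs, action_dim) and using SB3's sum_independent_dims:

    import torch as th
    from torch.distributions import Normal
    from stable_baselines3.common.distributions import sum_independent_dims

    mean = th.zeros(4, 2)        # (n_envs, action_dim), shapes assumed
    std = th.ones(4, 2)
    actions = th.randn(4, 2)

    dist = Normal(mean, std)
    log_prob = sum_independent_dims(dist.log_prob(actions))   # -> shape (4,), one value per env
    assert th.allclose(log_prob, dist.log_prob(actions).sum(dim=1))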
@@ -164,7 +176,13 @@ class PCA_Distribution(SB3_Distribution):
         return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
 
     def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
         # Ugly function to ensure, that the gradients flow as intended for each modus operandi
+        if not self.temporal_gradient_emission or self.skip_conditioning:
+            with th.no_grad():
+                return self._get_emitting_rigged(pi_mean, pi_std, rho_mean, rho_std, epsilon=epsilon).detach()
+        return self._get_emitting_rigged(pi_mean.detach(), pi_std.detach(), rho_mean, rho_std, epsilon=epsilon)
+
+    def _get_emitting_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
         if epsilon == None:
             epsilon = self.base_noise(pi_mean.shape)
 
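The new branch decides where gradients may flow: with temporal_gradient_emission disabled (or conditioning skipped) everything runs under no_grad and is detached, while with emission enabled only the policy statistics pi_* are detached, so the gradient can reach the conditioned prior rho_*. A toy illustration of that detach pattern (plain tensors, not the class):

    import torch as th

    pi_mean = th.zeros(4, 2, requires_grad=True)
    rho_mean = th.ones(4, 2, requires_grad=True)

    # Emission off: the whole result is detached, nothing receives gradients.
    eta_off = (rho_mean - pi_mean).detach()
    assert not eta_off.requires_grad

    # Emission on: detach only the policy stats, so gradients flow into rho_mean.
    eta_on = rho_mean - pi_mean.detach()
    eta_on.sum().backward()
    assert rho_mean.grad is not None and pi_mean.grad is None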
@@ -177,7 +195,7 @@ class PCA_Distribution(SB3_Distribution):
 
         eta = Pi_mu * Delta + Pi_sigma * epsilon
 
-        return eta.detach()
+        return eta
 
     def _pad_and_cut_trajectory(self, traj, value=0):
         if traj.shape[-2] < self.window:
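Dropping the final .detach() is what lets the emitted eta carry that gradient onward. Assuming the definitions just above this hunk are Delta = rho_mean - pi_mean, Pi_mu = 1 / pi_std and Pi_sigma = rho_std / pi_std (they sit outside the shown context, so this is an assumption), eta is exactly the noise that makes the policy Gaussian reproduce a sample from the conditioned prior:

    import torch as th

    pi_mean, pi_std = th.zeros(4, 2), th.full((4, 2), 0.5)
    rho_mean, rho_std = th.ones(4, 2), th.full((4, 2), 0.25)
    epsilon = th.randn(4, 2)

    # Assumed definitions (not visible in this hunk):
    Delta = rho_mean - pi_mean
    Pi_mu = 1 / pi_std
    Pi_sigma = rho_std / pi_std
    eta = Pi_mu * Delta + Pi_sigma * epsilon

    # pi_mean + pi_std * eta reproduces a draw from N(rho_mean, rho_std)
    assert th.allclose(pi_mean + pi_std * eta, rho_mean + rho_std * epsilon)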
@@ -194,18 +212,15 @@ class PCA_Distribution(SB3_Distribution):
             return pi_mean, pi_std
 
         traj = self._pad_and_cut_trajectory(trajectory)
 
-        # Numpy is fun
-        y_np = np.append(np.swapaxes(traj, -1, -2), np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
+        y = th.cat((traj.transpose(-1, -2), pi_mean.unsqueeze(-1).unsqueeze(0).repeat(traj.shape[0], 1, traj.shape[-2])), dim=1)
 
         with th.no_grad():
             conditioners = th.Tensor(self._adapt_conditioner(pi_std))
-            y = th.Tensor(y_np)
 
         S = th.cholesky_solve(self.Sig12.expand(conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
 
-        rho_mean = th.einsum('bai,bai->ba', S, y)
         rho_std = self.Sig22 - (S @ self.Sig12)
+        rho_mean = th.einsum('bai,bai->ba', S, y)
 
         return rho_mean, rho_std
 
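Apart from building y with native torch ops instead of the numpy detour, the conditioning math is unchanged: th.cholesky_solve solves K x = b against the precomputed Cholesky factors in conditioners, which is how S is obtained from Sig12. A tiny standalone check of that call, with a generic SPD matrix rather than the class's kernel matrices:

    import torch as th

    K = th.tensor([[2.0, 0.5],
                   [0.5, 1.0]])              # small SPD stand-in for a kernel matrix
    b = th.tensor([[1.0], [3.0]])
    L = th.linalg.cholesky(K)                # lower-triangular factor

    x = th.cholesky_solve(b, L)              # solves K x = b using the factor
    assert th.allclose(K @ x, b, atol=1e-6)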
@@ -254,13 +269,13 @@ class PCA_Distribution(SB3_Distribution):
         return self.distribution.mean
 
     def actions_from_params(
-        self, mean: th.Tensor, std: th.Tensor, deterministic: bool = False
+        self, mean: th.Tensor, std: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None
     ) -> th.Tensor:
         self.proba_distribution(mean, std)
-        return self.get_actions(deterministic=deterministic)
+        return self.get_actions(deterministic=deterministic, trajectory=trajectory)
 
-    def log_prob_from_params(self, mean: th.Tensor, std: th.Tensor):
-        actions = self.actions_from_params(mean, std)
+    def log_prob_from_params(self, mean: th.Tensor, std: th.Tensor, trajectory: th.Tensor = None):
+        actions = self.actions_from_params(mean, std, trajectory=trajectory)
         log_prob = self.log_prob(actions)
         return actions, log_prob
 
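End to end, the optional trajectory now threads from actions_from_params / log_prob_from_params down to the conditioner. A rough usage sketch; the tensor shapes, the trajectory layout and the constructor call are assumptions, only the keyword plumbing is taken from this commit:

    import torch as th

    n_envs, action_dim, window = 4, 2, 64
    dist = PCA_Distribution(action_dim=action_dim, n_envs=n_envs)   # import path assumed, see above

    mean = th.zeros(n_envs, action_dim)                  # policy head output (assumed shape)
    std = th.ones(n_envs, action_dim)
    past_actions = th.zeros(n_envs, window, action_dim)  # per-env action history (assumed layout)

    actions = dist.actions_from_params(mean, std, trajectory=past_actions)
    actions, log_prob = dist.log_prob_from_params(mean, std, trajectory=past_actions)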