Implement support for VecEnvs

This commit is contained in:
Dominik Moritz Roth 2024-03-09 13:41:27 +01:00
parent c68fd1635d
commit b0e2bc3a7a

View File

@ -89,6 +89,7 @@ class PCA_Distribution(SB3_Distribution):
def __init__(
self,
action_dim: int,
n_envs: int=1,
par_strength: Par_Strength = Par_Strength.CONT_DIAG,
kernel_func=kernel.rbf(),
init_std: float = 1,
@ -96,6 +97,7 @@ class PCA_Distribution(SB3_Distribution):
window: int = 64,
epsilon: float = 1e-6,
skip_conditioning: bool = False,
temporal_gradient_emission: bool = False,
Base_Noise=noise.White_Noise,
):
super().__init__()
@ -108,8 +110,9 @@ class PCA_Distribution(SB3_Distribution):
self.window = window
self.epsilon = epsilon
self.skip_conditioning = skip_conditioning
self.temporal_gradient_emission = temporal_gradient_emission
self.base_noise = cast_to_Noise(Base_Noise, (1, action_dim))
self.base_noise = cast_to_Noise(Base_Noise, (n_envs, action_dim))
# Premature optimization is the root of all evil
self._build_conditioner()
@ -128,7 +131,16 @@ class PCA_Distribution(SB3_Distribution):
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
return sum_independent_dims(self.distribution.log_prob(actions.to(self.distribution.mean.device)))
return self._log_prob(actions, self.distribution)
def conditioned_log_prob(self, actions: th.Tensor, trajectory: th.Tensor = None) -> th.Tensor:
pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
rho_mean, rho_std = self._conditioning_engine(trajectory, pi_mean, pi_std)
dist = Normal(rho_mean, rho_std)
return self._log_prob(dist)
def _log_prob(self, actions: th.Tensor, dist: Normal):
return sum_independent_dims(dist.log_prob(actions.to(dist.mean.device)))
def entropy(self) -> th.Tensor:
return sum_independent_dims(self.distribution.entropy())
@ -164,7 +176,13 @@ class PCA_Distribution(SB3_Distribution):
return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
# Ugly function to ensure, that the gradients flow as intended for each modus operandi
if not self.temporal_gradient_emission or self.skip_conditioning:
with th.no_grad():
return self._get_emitting_rigged(pi_mean, pi_std, rho_mean, rho_std, epsilon=epsilon).detach()
return self._get_emitting_rigged(pi_mean.detach(), pi_std.detach(), rho_mean, rho_std, epsilon=epsilon)
def _get_emitting_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
if epsilon == None:
epsilon = self.base_noise(pi_mean.shape)
@ -177,7 +195,7 @@ class PCA_Distribution(SB3_Distribution):
eta = Pi_mu * Delta + Pi_sigma * epsilon
return eta.detach()
return eta
def _pad_and_cut_trajectory(self, traj, value=0):
if traj.shape[-2] < self.window:
@ -194,18 +212,15 @@ class PCA_Distribution(SB3_Distribution):
return pi_mean, pi_std
traj = self._pad_and_cut_trajectory(trajectory)
# Numpy is fun
y_np = np.append(np.swapaxes(traj, -1, -2), np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
y = th.cat((traj.transpose(-1, -2), pi_mean.unsqueeze(-1).unsqueeze(0).repeat(traj.shape[0], 1, traj.shape[-2])), dim=1)
with th.no_grad():
conditioners = th.Tensor(self._adapt_conditioner(pi_std))
y = th.Tensor(y_np)
S = th.cholesky_solve(self.Sig12.expand(conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
rho_mean = th.einsum('bai,bai->ba', S, y)
rho_std = self.Sig22 - (S @ self.Sig12)
rho_mean = th.einsum('bai,bai->ba', S, y)
return rho_mean, rho_std
@ -254,13 +269,13 @@ class PCA_Distribution(SB3_Distribution):
return self.distribution.mean
def actions_from_params(
self, mean: th.Tensor, std: th.Tensor, deterministic: bool = False
self, mean: th.Tensor, std: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None
) -> th.Tensor:
self.proba_distribution(mean, std)
return self.get_actions(deterministic=deterministic)
return self.get_actions(deterministic=deterministic, trajectory=trajectory)
def log_prob_from_params(self, mean: th.Tensor, std: th.Tensor):
actions = self.actions_from_params(mean, std)
def log_prob_from_params(self, mean: th.Tensor, std: th.Tensor, trajectory: th.Tensor = None):
actions = self.actions_from_params(mean, std, trajectory=trajectory)
log_prob = self.log_prob(actions)
return actions, log_prob