Minor bug fixes
This commit is contained in:
parent
59adde5bd5
commit
dcde2150ac
@ -15,7 +15,7 @@ class Colored_Noise():
|
||||
self.reset(random_state=random_state)
|
||||
|
||||
def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
|
||||
assert shape == self.shape
|
||||
assert shape == self.knonw_shape
|
||||
sample = self.samples[:, self.index]
|
||||
self.index = (self.index+1) % self.num_samples
|
||||
return sample
|
||||
|
@ -145,6 +145,7 @@ class PCA_Distribution(SB3_Distribution):
|
||||
return self.sample(traj=trajectory)
|
||||
|
||||
def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
|
||||
assert self.skip_conditioning or type(traj) != type(None), 'A past trajectory has to be supplied if conditinoning is performed'
|
||||
pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
|
||||
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
|
||||
rho_std *= f_sigma
|
||||
@ -157,6 +158,8 @@ class PCA_Distribution(SB3_Distribution):
|
||||
return actions
|
||||
|
||||
def is_contextual(self):
|
||||
return True # TODO: Remove, when bug for non-contextual is fixed
|
||||
# Always returning True will merely waste cpu cycles
|
||||
return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
|
||||
|
||||
def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
|
||||
@ -177,6 +180,10 @@ class PCA_Distribution(SB3_Distribution):
|
||||
|
||||
def _pad_and_cut_trajectory(self, traj, value=0):
|
||||
if traj.shape[-2] < self.window:
|
||||
if traj.shape[-2] == 0:
|
||||
shape = list(traj.shape)
|
||||
shape[-2] = 1
|
||||
traj = th.ones(shape)*value
|
||||
missing = self.window - traj.shape[-2]
|
||||
return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
|
||||
return traj[:, -self.window:, :]
|
||||
@ -186,16 +193,15 @@ class PCA_Distribution(SB3_Distribution):
|
||||
return pi_mean, pi_std
|
||||
|
||||
traj = self._pad_and_cut_trajectory(trajectory)
|
||||
|
||||
# Numpy is fun
|
||||
y_np = np.append(np.swapaxes(traj, -1, -2),
|
||||
np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
|
||||
y_np = np.append(np.swapaxes(traj, -1, -2), np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
|
||||
|
||||
with th.no_grad():
|
||||
conditioners = th.Tensor(self._adapt_conditioner(pi_std))
|
||||
y = th.Tensor(y_np)
|
||||
|
||||
S = th.cholesky_solve(self.Sig12.expand(
|
||||
conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
|
||||
S = th.cholesky_solve(self.Sig12.expand(conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
|
||||
|
||||
rho_mean = th.einsum('bai,bai->ba', S, y)
|
||||
rho_std = self.Sig22 - (S @ self.Sig12)
|
||||
|
Loading…
Reference in New Issue
Block a user