Minor bug fixes
This commit is contained in:
parent
59adde5bd5
commit
dcde2150ac
@ -15,7 +15,7 @@ class Colored_Noise():
|
|||||||
self.reset(random_state=random_state)
|
self.reset(random_state=random_state)
|
||||||
|
|
||||||
def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
|
def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
|
||||||
assert shape == self.shape
|
assert shape == self.knonw_shape
|
||||||
sample = self.samples[:, self.index]
|
sample = self.samples[:, self.index]
|
||||||
self.index = (self.index+1) % self.num_samples
|
self.index = (self.index+1) % self.num_samples
|
||||||
return sample
|
return sample
|
||||||
|
@ -145,6 +145,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
return self.sample(traj=trajectory)
|
return self.sample(traj=trajectory)
|
||||||
|
|
||||||
def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
|
def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
|
||||||
|
assert self.skip_conditioning or type(traj) != type(None), 'A past trajectory has to be supplied if conditinoning is performed'
|
||||||
pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
|
pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
|
||||||
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
|
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
|
||||||
rho_std *= f_sigma
|
rho_std *= f_sigma
|
||||||
@ -157,6 +158,8 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
return actions
|
return actions
|
||||||
|
|
||||||
def is_contextual(self):
|
def is_contextual(self):
|
||||||
|
return True # TODO: Remove, when bug for non-contextual is fixed
|
||||||
|
# Always returning True will merely waste cpu cycles
|
||||||
return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
|
return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
|
||||||
|
|
||||||
def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
|
def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
|
||||||
@ -177,6 +180,10 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
|
|
||||||
def _pad_and_cut_trajectory(self, traj, value=0):
|
def _pad_and_cut_trajectory(self, traj, value=0):
|
||||||
if traj.shape[-2] < self.window:
|
if traj.shape[-2] < self.window:
|
||||||
|
if traj.shape[-2] == 0:
|
||||||
|
shape = list(traj.shape)
|
||||||
|
shape[-2] = 1
|
||||||
|
traj = th.ones(shape)*value
|
||||||
missing = self.window - traj.shape[-2]
|
missing = self.window - traj.shape[-2]
|
||||||
return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
|
return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
|
||||||
return traj[:, -self.window:, :]
|
return traj[:, -self.window:, :]
|
||||||
@ -186,16 +193,15 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
return pi_mean, pi_std
|
return pi_mean, pi_std
|
||||||
|
|
||||||
traj = self._pad_and_cut_trajectory(trajectory)
|
traj = self._pad_and_cut_trajectory(trajectory)
|
||||||
|
|
||||||
# Numpy is fun
|
# Numpy is fun
|
||||||
y_np = np.append(np.swapaxes(traj, -1, -2),
|
y_np = np.append(np.swapaxes(traj, -1, -2), np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
|
||||||
np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
|
|
||||||
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
conditioners = th.Tensor(self._adapt_conditioner(pi_std))
|
conditioners = th.Tensor(self._adapt_conditioner(pi_std))
|
||||||
y = th.Tensor(y_np)
|
y = th.Tensor(y_np)
|
||||||
|
|
||||||
S = th.cholesky_solve(self.Sig12.expand(
|
S = th.cholesky_solve(self.Sig12.expand(conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
|
||||||
conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
|
|
||||||
|
|
||||||
rho_mean = th.einsum('bai,bai->ba', S, y)
|
rho_mean = th.einsum('bai,bai->ba', S, y)
|
||||||
rho_std = self.Sig22 - (S @ self.Sig12)
|
rho_std = self.Sig22 - (S @ self.Sig12)
|
||||||
|
Loading…
Reference in New Issue
Block a user