From 47815f8a5fb39dfc347d044a7b83322113a08fe9 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 22 May 2023 17:58:50 +0200 Subject: [PATCH] Fixed bug with Colored Noise returning wrong shapes --- priorConditionedAnnealing/noise.py | 5 +++-- priorConditionedAnnealing/pca.py | 4 +--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/priorConditionedAnnealing/noise.py b/priorConditionedAnnealing/noise.py index 230711d..e6db249 100644 --- a/priorConditionedAnnealing/noise.py +++ b/priorConditionedAnnealing/noise.py @@ -9,6 +9,7 @@ class Colored_Noise(): def __init__(self, known_shape=None, beta=1, num_samples=2**14, random_state=None): assert known_shape, 'known_shape need to be defined for Colored Noise' self.known_shape = known_shape + self.compact_shape = np.prod(list(known_shape)) self.beta = beta self.num_samples = num_samples # Actually very cheap... self.index = 0 @@ -18,11 +19,11 @@ class Colored_Noise(): assert shape == self.known_shape sample = self.samples[:, self.index] self.index = (self.index+1) % self.num_samples - return th.Tensor(sample) + return th.Tensor(sample).view(self.known_shape) def reset(self, random_state=None): self.samples = cn.powerlaw_psd_gaussian( - self.beta, self.known_shape + (self.num_samples,), random_state=random_state) + self.beta, (self.compact_shape, self.num_samples), random_state=random_state) class Pink_Noise(Colored_Noise): diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py index a1dfaa9..47a278a 100644 --- a/priorConditionedAnnealing/pca.py +++ b/priorConditionedAnnealing/pca.py @@ -108,9 +108,6 @@ class PCA_Distribution(SB3_Distribution): self.base_noise = cast_to_Noise(Base_Noise, (1, action_dim)) - if not isinstance(self.base_noise, noise.White_Noise): - print('[!] Non-White Noise was not yet tested!') - # Premature optimization is the root of all evil self._build_conditioner() # *Optimizes it anyways* @@ -154,6 +151,7 @@ class PCA_Distribution(SB3_Distribution): epsilon) # reparameterization with rigged samples actions = pi_mean + pi_std * eta + self.gaussian_actions = actions return actions