diff --git a/priorConditionedAnnealing/noise.py b/priorConditionedAnnealing/noise.py index 9f7978f..df9e373 100644 --- a/priorConditionedAnnealing/noise.py +++ b/priorConditionedAnnealing/noise.py @@ -17,10 +17,10 @@ class Colored_Noise(): self.reset(random_state=random_state) def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor: - assert shape == self.known_shape + assert shape == self.known_shape or (shape[1:] == self.known_shape[1:] and shape[0] <= self.known_shape[0]) sample = self.samples[:, self.index] self.index = (self.index+1) % self.num_samples - return th.Tensor(sample).view(self.known_shape) + return th.Tensor(sample).view(self.known_shape)[:shape[0]] def reset(self, random_state=None): self.samples = cn.powerlaw_psd_gaussian(