From 4485e558a83eccb5e4fb459b0a22ec53a35a5a29 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 9 Mar 2024 14:02:50 +0100 Subject: [PATCH] Allow sampling noises with reduced then known batch size --- priorConditionedAnnealing/noise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/priorConditionedAnnealing/noise.py b/priorConditionedAnnealing/noise.py index 9f7978f..df9e373 100644 --- a/priorConditionedAnnealing/noise.py +++ b/priorConditionedAnnealing/noise.py @@ -17,10 +17,10 @@ class Colored_Noise(): self.reset(random_state=random_state) def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor: - assert shape == self.known_shape + assert shape == self.known_shape or (shape[1:] == self.known_shape[1:] and shape[0] <= self.known_shape[0]) sample = self.samples[:, self.index] self.index = (self.index+1) % self.num_samples - return th.Tensor(sample).view(self.known_shape) + return th.Tensor(sample).view(self.known_shape)[:shape[0]] def reset(self, random_state=None): self.samples = cn.powerlaw_psd_gaussian(