Allow sampling noises with reduced then known batch size

This commit is contained in:
Dominik Moritz Roth 2024-03-09 14:02:50 +01:00
parent b0e2bc3a7a
commit 4485e558a8

View File

@ -17,10 +17,10 @@ class Colored_Noise():
self.reset(random_state=random_state)
def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
assert shape == self.known_shape
assert shape == self.known_shape or (shape[1:] == self.known_shape[1:] and shape[0] <= self.known_shape[0])
sample = self.samples[:, self.index]
self.index = (self.index+1) % self.num_samples
return th.Tensor(sample).view(self.known_shape)
return th.Tensor(sample).view(self.known_shape)[:shape[0]]
def reset(self, random_state=None):
self.samples = cn.powerlaw_psd_gaussian(