Allow sampling noises with reduced then known batch size
This commit is contained in:
parent
b0e2bc3a7a
commit
4485e558a8
@ -17,10 +17,10 @@ class Colored_Noise():
|
||||
self.reset(random_state=random_state)
|
||||
|
||||
def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
|
||||
assert shape == self.known_shape
|
||||
assert shape == self.known_shape or (shape[1:] == self.known_shape[1:] and shape[0] <= self.known_shape[0])
|
||||
sample = self.samples[:, self.index]
|
||||
self.index = (self.index+1) % self.num_samples
|
||||
return th.Tensor(sample).view(self.known_shape)
|
||||
return th.Tensor(sample).view(self.known_shape)[:shape[0]]
|
||||
|
||||
def reset(self, random_state=None):
|
||||
self.samples = cn.powerlaw_psd_gaussian(
|
||||
|
Loading…
Reference in New Issue
Block a user