Allow sampling noises with reduced then known batch size
This commit is contained in:
parent
b0e2bc3a7a
commit
4485e558a8
@ -17,10 +17,10 @@ class Colored_Noise():
|
|||||||
self.reset(random_state=random_state)
|
self.reset(random_state=random_state)
|
||||||
|
|
||||||
def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
|
def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
|
||||||
assert shape == self.known_shape
|
assert shape == self.known_shape or (shape[1:] == self.known_shape[1:] and shape[0] <= self.known_shape[0])
|
||||||
sample = self.samples[:, self.index]
|
sample = self.samples[:, self.index]
|
||||||
self.index = (self.index+1) % self.num_samples
|
self.index = (self.index+1) % self.num_samples
|
||||||
return th.Tensor(sample).view(self.known_shape)
|
return th.Tensor(sample).view(self.known_shape)[:shape[0]]
|
||||||
|
|
||||||
def reset(self, random_state=None):
|
def reset(self, random_state=None):
|
||||||
self.samples = cn.powerlaw_psd_gaussian(
|
self.samples = cn.powerlaw_psd_gaussian(
|
||||||
|
Loading…
Reference in New Issue
Block a user