From bb9a76a07a5a571dacd958e064c16727803255b5 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 26 Jun 2023 16:38:11 +0200 Subject: [PATCH] Fallback to known_shape, when no shape is provided --- priorConditionedAnnealing/noise.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/priorConditionedAnnealing/noise.py b/priorConditionedAnnealing/noise.py index 2d5d8ea..9247e05 100644 --- a/priorConditionedAnnealing/noise.py +++ b/priorConditionedAnnealing/noise.py @@ -35,7 +35,9 @@ class White_Noise(): def __init__(self, known_shape=None): self.known_shape = known_shape - def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor: + def __call__(self, shape=None, latent: th.Tensor = None) -> th.Tensor: + if shape == None: + shape = self.known_shape return th.Tensor(np.random.normal(0, 1, shape)) @@ -99,7 +101,9 @@ class Perlin_Noise(): self.normal_factor = 14/99 self.reset() - def __call__(self, shape): + def __call__(self, shape=None): + if shape == None: + shape = self.known_shape self.index += 1 noise = [self.noise([self.index*self.scale, self.magic*a]) / self.normal_factor for a in range(shape[-1])] @@ -114,16 +118,19 @@ class Harmonic_Perlin_Noise(): def __init__(self, known_shape=None, scale=0.1, octaves=8): self.known_shape = known_shape self.scale = scale + assert octaves >= 1 if type(octaves) in [int, float]: int_octaves = int(octaves) octaves_arr = [1/(i+1) for i in range(int_octaves)] - if type(octaves) == float: + if int_octaves != octaves: octaves_arr += [1/(int_octaves+2)*(octaves-int_octaves)] octaves_arr = np.array(octaves_arr) self.octaves = octaves_arr / np.linalg.norm(octaves_arr) self.reset() - def __call__(self, shape): + def __call__(self, shape=None): + if shape == None: + shape = self.known_shape harmonics = [noise(shape)*self.octaves[i] for i, noise in enumerate(self.noises)] return sum(harmonics)