Fallback to known_shape, when no shape is provided

This commit is contained in:
Dominik Moritz Roth 2023-06-26 16:38:11 +02:00
parent bec6a5ffcd
commit bb9a76a07a

View File

@ -35,7 +35,9 @@ class White_Noise():
def __init__(self, known_shape=None): def __init__(self, known_shape=None):
self.known_shape = known_shape self.known_shape = known_shape
def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor: def __call__(self, shape=None, latent: th.Tensor = None) -> th.Tensor:
if shape == None:
shape = self.known_shape
return th.Tensor(np.random.normal(0, 1, shape)) return th.Tensor(np.random.normal(0, 1, shape))
@ -99,7 +101,9 @@ class Perlin_Noise():
self.normal_factor = 14/99 self.normal_factor = 14/99
self.reset() self.reset()
def __call__(self, shape): def __call__(self, shape=None):
if shape == None:
shape = self.known_shape
self.index += 1 self.index += 1
noise = [self.noise([self.index*self.scale, self.magic*a]) / self.normal_factor noise = [self.noise([self.index*self.scale, self.magic*a]) / self.normal_factor
for a in range(shape[-1])] for a in range(shape[-1])]
@ -114,16 +118,19 @@ class Harmonic_Perlin_Noise():
def __init__(self, known_shape=None, scale=0.1, octaves=8): def __init__(self, known_shape=None, scale=0.1, octaves=8):
self.known_shape = known_shape self.known_shape = known_shape
self.scale = scale self.scale = scale
assert octaves >= 1
if type(octaves) in [int, float]: if type(octaves) in [int, float]:
int_octaves = int(octaves) int_octaves = int(octaves)
octaves_arr = [1/(i+1) for i in range(int_octaves)] octaves_arr = [1/(i+1) for i in range(int_octaves)]
if type(octaves) == float: if int_octaves != octaves:
octaves_arr += [1/(int_octaves+2)*(octaves-int_octaves)] octaves_arr += [1/(int_octaves+2)*(octaves-int_octaves)]
octaves_arr = np.array(octaves_arr) octaves_arr = np.array(octaves_arr)
self.octaves = octaves_arr / np.linalg.norm(octaves_arr) self.octaves = octaves_arr / np.linalg.norm(octaves_arr)
self.reset() self.reset()
def __call__(self, shape): def __call__(self, shape=None):
if shape == None:
shape = self.known_shape
harmonics = [noise(shape)*self.octaves[i] for i, noise in enumerate(self.noises)] harmonics = [noise(shape)*self.octaves[i] for i, noise in enumerate(self.noises)]
return sum(harmonics) return sum(harmonics)