Some additions to make configuration from config files (cw2) easier

This commit is contained in:
Dominik Moritz Roth 2023-05-21 16:18:42 +02:00
parent 539f586751
commit 224fc82a48
2 changed files with 45 additions and 5 deletions

View File

@ -6,11 +6,11 @@ from torch.distributions import Normal
class Colored_Noise(): class Colored_Noise():
def __init__(self, known_shape=None, beta=1, num_samples=2**16, random_state=None): def __init__(self, known_shape=None, beta=1, num_samples=2**14, random_state=None):
assert known_shape, 'known_shape need to be defined for Colored Noise' assert known_shape, 'known_shape need to be defined for Colored Noise'
self.known_shape = known_shape self.known_shape = known_shape
self.beta = beta self.beta = beta
self.num_samples = num_samples self.num_samples = num_samples # Actually very cheap...
self.index = 0 self.index = 0
self.reset(random_state=random_state) self.reset(random_state=random_state)
@ -25,6 +25,11 @@ class Colored_Noise():
self.beta, self.shape + (self.num_samples,), random_state=random_state) self.beta, self.shape + (self.num_samples,), random_state=random_state)
class Pink_Noise(Colored_Noise):
    """Colored_Noise with the exponent ``beta`` fixed to 1 (pink noise).

    Exists so that pink noise can be requested through a constructor that
    has no ``beta`` argument; all other arguments are forwarded unchanged.
    """

    def __init__(self, known_shape=None, num_samples=2**14, random_state=None):
        """Forward everything to Colored_Noise, pinning ``beta=1``.

        :param known_shape: shape handed to the parent class.
        :param num_samples: number of samples handed to the parent class.
        :param random_state: random state handed to the parent class.
        """
        super().__init__(
            known_shape=known_shape,
            beta=1,
            num_samples=num_samples,
            random_state=random_state,
        )
class White_Noise(): class White_Noise():
def __init__(self, known_shape=None): def __init__(self, known_shape=None):
self.known_shape = known_shape self.known_shape = known_shape
@ -36,6 +41,8 @@ class White_Noise():
def get_colored_noise(beta, known_shape=None): def get_colored_noise(beta, known_shape=None):
if beta == 0: if beta == 0:
return White_Noise(known_shape) return White_Noise(known_shape)
elif beta == 1:
return Pink_Noise(known_shape)
else: else:
return Colored_Noise(known_shape, beta=beta) return Colored_Noise(known_shape, beta=beta)
@ -86,14 +93,14 @@ class Perlin_Noise():
self.known_shape = known_shape self.known_shape = known_shape
self.scale = scale self.scale = scale
self.octaves = octaves self.octaves = octaves
self.magic = 3.14159 # Axis offset self.magic = 0.141592653589 # Axis offset, should be (kinda) irrational
# We want to genrate samples, that approx ~N(0,1) # We want to genrate samples, that approx ~N(0,1)
self.normal_factor = 0.0471 self.normal_factor = 0.0471
self.reset() self.reset()
def __call__(self, shape): def __call__(self, shape):
self.index += 1 self.index += 1
return [self.noise([self.index*self.scale, self.magic*a]) / self.normal_factor return [self.noise([self.index*self.scale, self.magic*(a+1)]) / self.normal_factor
for a in range(self.shape[-1])] for a in range(self.shape[-1])]
def reset(self): def reset(self):

View File

@ -45,6 +45,18 @@ class Avaible_Kernel_Funcs(Enum):
return [kernel.rbf, kernel.se, kernel.brown, kernel.pink][self.value] return [kernel.rbf, kernel.se, kernel.brown, kernel.pink][self.value]
class Avaible_Noise_Funcs(Enum):
    """Names under which noise classes can be selected (e.g. from a config string).

    Each member's integer value is its position in the lookup sequence
    used by :meth:`get_func`.
    """

    WHITE = 0
    PINK = 1
    COLOR = 2
    PERLIN = 3
    SDE = 4

    def get_func(self):
        """Return the class from the ``noise`` module that this member selects."""
        # Order must stay aligned with the member values above, because the
        # member's value is used directly as the index.
        noise_classes = (
            noise.White_Noise,
            noise.Pink_Noise,
            noise.Colored_Noise,
            noise.Perlin_Noise,
            noise.SDE_Noise,
        )
        return noise_classes[self.value]
def cast_to_enum(inp, Class): def cast_to_enum(inp, Class):
if isinstance(inp, Enum): if isinstance(inp, Enum):
return inp return inp
@ -61,6 +73,15 @@ def cast_to_kernel(inp):
return Avaible_Kernel_Funcs[func].get_func()(*pars) return Avaible_Kernel_Funcs[func].get_func()(*pars)
def cast_to_Noise(Inp, known_shape):
    """Turn a noise specification into a noise object for ``known_shape``.

    :param Inp: either a callable, which is called with ``known_shape``, or a
        string of the form ``NAME`` / ``NAME_p1_p2...`` where ``NAME`` is a
        member name of ``Avaible_Noise_Funcs`` and every ``p`` parses as a float.
    :param known_shape: passed as the first argument to whatever is called.
    :return: the result of calling ``Inp`` or the looked-up noise class.
    """
    # TODO: Allow instantiated?
    if callable(Inp):
        return Inp(known_shape)

    # String form: the first '_'-separated token names the noise type, the
    # remaining tokens are its numeric arguments.
    name, *raw_args = Inp.split('_')
    # Parse the numbers before the name lookup so that a malformed number is
    # reported ahead of an unknown name.
    numeric_args = [float(token) for token in raw_args]
    noise_cls = Avaible_Noise_Funcs[name].get_func()
    return noise_cls(known_shape, *numeric_args)
class PCA_Distribution(SB3_Distribution): class PCA_Distribution(SB3_Distribution):
def __init__( def __init__(
self, self,
@ -85,7 +106,7 @@ class PCA_Distribution(SB3_Distribution):
self.epsilon = epsilon self.epsilon = epsilon
self.skip_conditioning = skip_conditioning self.skip_conditioning = skip_conditioning
self.base_noise = Base_Noise((1, action_dim)) self.base_noise = cast_to_Noise(Base_Noise, (1, action_dim))
if not isinstance(self.base_noise, noise.White_Noise): if not isinstance(self.base_noise, noise.White_Noise):
print('[!] Non-White Noise was not yet tested!') print('[!] Non-White Noise was not yet tested!')
@ -93,6 +114,7 @@ class PCA_Distribution(SB3_Distribution):
# Premature optimization is the root of all evil # Premature optimization is the root of all evil
self._build_conditioner() self._build_conditioner()
# *Optimizes it anyways* # *Optimizes it anyways*
print('[i] PCA-Distribution initialized')
def proba_distribution_net(self, latent_dim: int): def proba_distribution_net(self, latent_dim: int):
mu_net = nn.Linear(latent_dim, self.action_dim) mu_net = nn.Linear(latent_dim, self.action_dim)
@ -113,6 +135,17 @@ class PCA_Distribution(SB3_Distribution):
def entropy(self) -> th.Tensor: def entropy(self) -> th.Tensor:
return sum_independent_dims(self.distribution.entropy()) return sum_independent_dims(self.distribution.entropy())
def get_actions(self, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
    """
    Return actions according to the probability distribution.

    :param deterministic: if True, return ``self.mode()`` instead of sampling.
    :param trajectory: forwarded to ``self.sample`` as its first argument
        when sampling; ignored when ``deterministic`` is True.
    :return: the result of ``self.mode()`` or ``self.sample(...)``.
    """
    if deterministic:
        return self.mode()
    # ``sample`` names this parameter ``traj``, so passing it as the keyword
    # ``trajectory=`` raised a TypeError. Pass it positionally instead.
    return self.sample(trajectory)
def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor: def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
pi_mean, pi_std = self.distribution.mean, self.distribution.scale pi_mean, pi_std = self.distribution.mean, self.distribution.scale
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std) rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)