Fixed bug with Colored Noise returning wrong shapes
This commit is contained in:
parent
55cbd734c0
commit
47815f8a5f
@ -9,6 +9,7 @@ class Colored_Noise():
|
|||||||
def __init__(self, known_shape=None, beta=1, num_samples=2**14, random_state=None):
|
def __init__(self, known_shape=None, beta=1, num_samples=2**14, random_state=None):
|
||||||
assert known_shape, 'known_shape need to be defined for Colored Noise'
|
assert known_shape, 'known_shape need to be defined for Colored Noise'
|
||||||
self.known_shape = known_shape
|
self.known_shape = known_shape
|
||||||
|
self.compact_shape = np.prod(list(known_shape))
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.num_samples = num_samples # Actually very cheap...
|
self.num_samples = num_samples # Actually very cheap...
|
||||||
self.index = 0
|
self.index = 0
|
||||||
@ -18,11 +19,11 @@ class Colored_Noise():
|
|||||||
assert shape == self.known_shape
|
assert shape == self.known_shape
|
||||||
sample = self.samples[:, self.index]
|
sample = self.samples[:, self.index]
|
||||||
self.index = (self.index+1) % self.num_samples
|
self.index = (self.index+1) % self.num_samples
|
||||||
return th.Tensor(sample)
|
return th.Tensor(sample).view(self.known_shape)
|
||||||
|
|
||||||
def reset(self, random_state=None):
|
def reset(self, random_state=None):
|
||||||
self.samples = cn.powerlaw_psd_gaussian(
|
self.samples = cn.powerlaw_psd_gaussian(
|
||||||
self.beta, self.known_shape + (self.num_samples,), random_state=random_state)
|
self.beta, (self.compact_shape, self.num_samples), random_state=random_state)
|
||||||
|
|
||||||
|
|
||||||
class Pink_Noise(Colored_Noise):
|
class Pink_Noise(Colored_Noise):
|
||||||
|
@ -108,9 +108,6 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
|
|
||||||
self.base_noise = cast_to_Noise(Base_Noise, (1, action_dim))
|
self.base_noise = cast_to_Noise(Base_Noise, (1, action_dim))
|
||||||
|
|
||||||
if not isinstance(self.base_noise, noise.White_Noise):
|
|
||||||
print('[!] Non-White Noise was not yet tested!')
|
|
||||||
|
|
||||||
# Premature optimization is the root of all evil
|
# Premature optimization is the root of all evil
|
||||||
self._build_conditioner()
|
self._build_conditioner()
|
||||||
# *Optimizes it anyways*
|
# *Optimizes it anyways*
|
||||||
@ -154,6 +151,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
epsilon)
|
epsilon)
|
||||||
# reparameterization with rigged samples
|
# reparameterization with rigged samples
|
||||||
actions = pi_mean + pi_std * eta
|
actions = pi_mean + pi_std * eta
|
||||||
|
|
||||||
self.gaussian_actions = actions
|
self.gaussian_actions = actions
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user