Smashing bugs (dimension mismatch between Normal and
Independent/MultivariateNormal)
This commit is contained in:
parent
ab557a8856
commit
a86d19053d
@ -41,8 +41,8 @@ class ParametrizationType(Enum):
|
|||||||
CHOL = 1
|
CHOL = 1
|
||||||
SPHERICAL_CHOL = 2
|
SPHERICAL_CHOL = 2
|
||||||
# Not (yet?) implemented:
|
# Not (yet?) implemented:
|
||||||
#GIVENS = 3
|
# GIVENS = 3
|
||||||
#NNLN_EIGEN = 4
|
# NNLN_EIGEN = 4
|
||||||
|
|
||||||
|
|
||||||
class EnforcePositiveType(Enum):
|
class EnforcePositiveType(Enum):
|
||||||
@ -360,13 +360,15 @@ class CholNet(nn.Module):
|
|||||||
raise Exception()
|
raise Exception()
|
||||||
|
|
||||||
def _chol_from_flat(self, flat_chol):
|
def _chol_from_flat(self, flat_chol):
|
||||||
chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
|
# chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
|
||||||
|
chol = fill_triangular(flat_chol)
|
||||||
return self._ensure_diagonal_positive(chol)
|
return self._ensure_diagonal_positive(chol)
|
||||||
|
|
||||||
def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
|
def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
|
||||||
pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
|
pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
|
||||||
sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
|
# sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
|
||||||
self._flat_chol_len, -1, -1)
|
# self._flat_chol_len, -1, -1)
|
||||||
|
sphe_chol = fill_triangular(pos_flat_sphe_chol)
|
||||||
chol = self._chol_from_sphe_chol(sphe_chol)
|
chol = self._chol_from_sphe_chol(sphe_chol)
|
||||||
return chol
|
return chol
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class GaussianRolloutBufferSamples(NamedTuple):
|
|||||||
advantages: th.Tensor
|
advantages: th.Tensor
|
||||||
returns: th.Tensor
|
returns: th.Tensor
|
||||||
means: th.Tensor
|
means: th.Tensor
|
||||||
stds: th.Tensor
|
chols: th.Tensor
|
||||||
|
|
||||||
|
|
||||||
class GaussianRolloutBuffer(RolloutBuffer):
|
class GaussianRolloutBuffer(RolloutBuffer):
|
||||||
@ -56,7 +56,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
|||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.means = np.zeros(
|
self.means = np.zeros(
|
||||||
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
||||||
self.stds = np.zeros(
|
self.chols = np.zeros(
|
||||||
(self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
|
(self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
|
||||||
super().reset()
|
super().reset()
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
|||||||
value: th.Tensor,
|
value: th.Tensor,
|
||||||
log_prob: th.Tensor,
|
log_prob: th.Tensor,
|
||||||
mean: th.Tensor,
|
mean: th.Tensor,
|
||||||
std: th.Tensor,
|
chol: th.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
:param obs: Observation
|
:param obs: Observation
|
||||||
@ -80,8 +80,8 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
|||||||
following the current policy.
|
following the current policy.
|
||||||
:param log_prob: log probability of the action
|
:param log_prob: log probability of the action
|
||||||
following the current policy.
|
following the current policy.
|
||||||
:param mean: Foo
|
:param mean:
|
||||||
:param std: Bar
|
:param chol:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if len(log_prob.shape) == 0:
|
if len(log_prob.shape) == 0:
|
||||||
@ -100,7 +100,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
|||||||
self.values[self.pos] = value.clone().cpu().numpy().flatten()
|
self.values[self.pos] = value.clone().cpu().numpy().flatten()
|
||||||
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
|
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
|
||||||
self.means[self.pos] = mean.clone().cpu().numpy()
|
self.means[self.pos] = mean.clone().cpu().numpy()
|
||||||
self.stds[self.pos] = std.clone().cpu().numpy()
|
self.chols[self.pos] = chol.clone().cpu().numpy()
|
||||||
self.pos += 1
|
self.pos += 1
|
||||||
if self.pos == self.buffer_size:
|
if self.pos == self.buffer_size:
|
||||||
self.full = True
|
self.full = True
|
||||||
@ -114,7 +114,8 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
|||||||
self.advantages[batch_inds].flatten(),
|
self.advantages[batch_inds].flatten(),
|
||||||
self.returns[batch_inds].flatten(),
|
self.returns[batch_inds].flatten(),
|
||||||
self.means[batch_inds].reshape((len(batch_inds), -1)),
|
self.means[batch_inds].reshape((len(batch_inds), -1)),
|
||||||
self.stds[batch_inds].reshape((len(batch_inds), -1)),
|
self.chols[batch_inds].reshape(
|
||||||
|
(len(batch_inds),) + self.cov_shape),
|
||||||
)
|
)
|
||||||
return GaussianRolloutBufferSamples(*tuple(map(self.to_torch, data)))
|
return GaussianRolloutBufferSamples(*tuple(map(self.to_torch, data)))
|
||||||
|
|
||||||
@ -181,7 +182,7 @@ class GaussianRolloutCollectorAuxclass():
|
|||||||
obs_tensor = obs_as_tensor(self._last_obs, self.device)
|
obs_tensor = obs_as_tensor(self._last_obs, self.device)
|
||||||
actions, values, log_probs = self.policy(obs_tensor)
|
actions, values, log_probs = self.policy(obs_tensor)
|
||||||
dist = self.policy.get_distribution(obs_tensor).distribution
|
dist = self.policy.get_distribution(obs_tensor).distribution
|
||||||
mean, std = dist.mean, dist.stddev
|
mean, chol = get_mean_and_chol(dist)
|
||||||
actions = actions.cpu().numpy()
|
actions = actions.cpu().numpy()
|
||||||
|
|
||||||
# Rescale and perform action
|
# Rescale and perform action
|
||||||
@ -223,7 +224,7 @@ class GaussianRolloutCollectorAuxclass():
|
|||||||
rewards[idx] += self.gamma * terminal_value
|
rewards[idx] += self.gamma * terminal_value
|
||||||
|
|
||||||
rollout_buffer.add(self._last_obs, actions, rewards,
|
rollout_buffer.add(self._last_obs, actions, rewards,
|
||||||
self._last_episode_starts, values, log_probs, mean, std)
|
self._last_episode_starts, values, log_probs, mean, chol)
|
||||||
self._last_obs = new_obs
|
self._last_obs = new_obs
|
||||||
self._last_episode_starts = dones
|
self._last_episode_starts = dones
|
||||||
|
|
||||||
|
@ -246,7 +246,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
p = pol._get_action_dist_from_latent(latent_pi)
|
p = pol._get_action_dist_from_latent(latent_pi)
|
||||||
p_dist = p.distribution
|
p_dist = p.distribution
|
||||||
q_dist = new_dist_like(
|
q_dist = new_dist_like(
|
||||||
p_dist, rollout_data.means, rollout_data.stds)
|
p_dist, rollout_data.means, rollout_data.chols)
|
||||||
proj_p = self.projection(p_dist, q_dist, self._global_steps)
|
proj_p = self.projection(p_dist, q_dist, self._global_steps)
|
||||||
if isinstance(p_dist, th.distributions.Normal):
|
if isinstance(p_dist, th.distributions.Normal):
|
||||||
# Normal uses a weird mapping from dimensions into batch_shape
|
# Normal uses a weird mapping from dimensions into batch_shape
|
||||||
|
5
test.py
5
test.py
@ -26,8 +26,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=2_000_000, showRes=Tru
|
|||||||
ppo = PPO(
|
ppo = PPO(
|
||||||
MlpPolicy,
|
MlpPolicy,
|
||||||
env,
|
env,
|
||||||
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.DIAG, 'cov_strength': Strength.DIAG, 'parameterization_type':
|
projection=FrobeniusProjectionLayer(),
|
||||||
ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
|
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.FULL, 'cov_strength': Strength.FULL, 'parameterization_type':
|
||||||
|
ParametrizationType.CHOL, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
|
||||||
verbose=0,
|
verbose=0,
|
||||||
tensorboard_log=root_path+"/logs_tb/" +
|
tensorboard_log=root_path+"/logs_tb/" +
|
||||||
env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
|
env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
|
||||||
|
Loading…
Reference in New Issue
Block a user