Smashing bugs (dimension mismatch between Normal and

Independent/MultivariateNormal)
This commit is contained in:
Dominik Moritz Roth 2022-07-15 15:46:31 +02:00
parent ab557a8856
commit a86d19053d
4 changed files with 21 additions and 17 deletions

View File

@ -41,8 +41,8 @@ class ParametrizationType(Enum):
CHOL = 1
SPHERICAL_CHOL = 2
# Not (yet?) implemented:
#GIVENS = 3
#NNLN_EIGEN = 4
# GIVENS = 3
# NNLN_EIGEN = 4
class EnforcePositiveType(Enum):
@ -360,13 +360,15 @@ class CholNet(nn.Module):
raise Exception()
def _chol_from_flat(self, flat_chol):
chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
# chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
chol = fill_triangular(flat_chol)
return self._ensure_diagonal_positive(chol)
def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
self._flat_chol_len, -1, -1)
# sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
# self._flat_chol_len, -1, -1)
sphe_chol = fill_triangular(pos_flat_sphe_chol)
chol = self._chol_from_sphe_chol(sphe_chol)
return chol

View File

@ -26,7 +26,7 @@ class GaussianRolloutBufferSamples(NamedTuple):
advantages: th.Tensor
returns: th.Tensor
means: th.Tensor
stds: th.Tensor
chols: th.Tensor
class GaussianRolloutBuffer(RolloutBuffer):
@ -56,7 +56,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
def reset(self) -> None:
self.means = np.zeros(
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
self.stds = np.zeros(
self.chols = np.zeros(
(self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
super().reset()
@ -69,7 +69,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
value: th.Tensor,
log_prob: th.Tensor,
mean: th.Tensor,
std: th.Tensor,
chol: th.Tensor,
) -> None:
"""
:param obs: Observation
@ -80,8 +80,8 @@ class GaussianRolloutBuffer(RolloutBuffer):
following the current policy.
:param log_prob: log probability of the action
following the current policy.
:param mean: Foo
:param std: Bar
:param mean:
:param chol:
"""
if len(log_prob.shape) == 0:
@ -100,7 +100,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.means[self.pos] = mean.clone().cpu().numpy()
self.stds[self.pos] = std.clone().cpu().numpy()
self.chols[self.pos] = chol.clone().cpu().numpy()
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
@ -114,7 +114,8 @@ class GaussianRolloutBuffer(RolloutBuffer):
self.advantages[batch_inds].flatten(),
self.returns[batch_inds].flatten(),
self.means[batch_inds].reshape((len(batch_inds), -1)),
self.stds[batch_inds].reshape((len(batch_inds), -1)),
self.chols[batch_inds].reshape(
(len(batch_inds),) + self.cov_shape),
)
return GaussianRolloutBufferSamples(*tuple(map(self.to_torch, data)))
@ -181,7 +182,7 @@ class GaussianRolloutCollectorAuxclass():
obs_tensor = obs_as_tensor(self._last_obs, self.device)
actions, values, log_probs = self.policy(obs_tensor)
dist = self.policy.get_distribution(obs_tensor).distribution
mean, std = dist.mean, dist.stddev
mean, chol = get_mean_and_chol(dist)
actions = actions.cpu().numpy()
# Rescale and perform action
@ -223,7 +224,7 @@ class GaussianRolloutCollectorAuxclass():
rewards[idx] += self.gamma * terminal_value
rollout_buffer.add(self._last_obs, actions, rewards,
self._last_episode_starts, values, log_probs, mean, std)
self._last_episode_starts, values, log_probs, mean, chol)
self._last_obs = new_obs
self._last_episode_starts = dones

View File

@ -246,7 +246,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
p = pol._get_action_dist_from_latent(latent_pi)
p_dist = p.distribution
q_dist = new_dist_like(
p_dist, rollout_data.means, rollout_data.stds)
p_dist, rollout_data.means, rollout_data.chols)
proj_p = self.projection(p_dist, q_dist, self._global_steps)
if isinstance(p_dist, th.distributions.Normal):
# Normal uses a weird mapping from dimensions into batch_shape

View File

@ -26,8 +26,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=2_000_000, showRes=Tru
ppo = PPO(
MlpPolicy,
env,
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.DIAG, 'cov_strength': Strength.DIAG, 'parameterization_type':
ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
projection=FrobeniusProjectionLayer(),
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.FULL, 'cov_strength': Strength.FULL, 'parameterization_type':
ParametrizationType.CHOL, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
verbose=0,
tensorboard_log=root_path+"/logs_tb/" +
env_name+"/ppo"+(['', '_sde'][use_sde])+"/",