Smashing bugs (dimension mismatch between Normal and Independent/MultivariateNormal)
This commit is contained in:
Dominik Moritz Roth 2022-07-15 15:46:31 +02:00
parent ab557a8856
commit a86d19053d
4 changed files with 21 additions and 17 deletions

View File

@ -41,8 +41,8 @@ class ParametrizationType(Enum):
CHOL = 1 CHOL = 1
SPHERICAL_CHOL = 2 SPHERICAL_CHOL = 2
# Not (yet?) implemented: # Not (yet?) implemented:
#GIVENS = 3 # GIVENS = 3
#NNLN_EIGEN = 4 # NNLN_EIGEN = 4
class EnforcePositiveType(Enum): class EnforcePositiveType(Enum):
@ -360,13 +360,15 @@ class CholNet(nn.Module):
raise Exception() raise Exception()
def _chol_from_flat(self, flat_chol): def _chol_from_flat(self, flat_chol):
chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1) # chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
chol = fill_triangular(flat_chol)
return self._ensure_diagonal_positive(chol) return self._ensure_diagonal_positive(chol)
def _chol_from_flat_sphe_chol(self, flat_sphe_chol): def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol) pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
sphe_chol = fill_triangular(pos_flat_sphe_chol).expand( # sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
self._flat_chol_len, -1, -1) # self._flat_chol_len, -1, -1)
sphe_chol = fill_triangular(pos_flat_sphe_chol)
chol = self._chol_from_sphe_chol(sphe_chol) chol = self._chol_from_sphe_chol(sphe_chol)
return chol return chol

View File

@ -26,7 +26,7 @@ class GaussianRolloutBufferSamples(NamedTuple):
advantages: th.Tensor advantages: th.Tensor
returns: th.Tensor returns: th.Tensor
means: th.Tensor means: th.Tensor
stds: th.Tensor chols: th.Tensor
class GaussianRolloutBuffer(RolloutBuffer): class GaussianRolloutBuffer(RolloutBuffer):
@ -56,7 +56,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
def reset(self) -> None: def reset(self) -> None:
self.means = np.zeros( self.means = np.zeros(
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32) (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
self.stds = np.zeros( self.chols = np.zeros(
(self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32) (self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
super().reset() super().reset()
@ -69,7 +69,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
value: th.Tensor, value: th.Tensor,
log_prob: th.Tensor, log_prob: th.Tensor,
mean: th.Tensor, mean: th.Tensor,
std: th.Tensor, chol: th.Tensor,
) -> None: ) -> None:
""" """
:param obs: Observation :param obs: Observation
@ -80,8 +80,8 @@ class GaussianRolloutBuffer(RolloutBuffer):
following the current policy. following the current policy.
:param log_prob: log probability of the action :param log_prob: log probability of the action
following the current policy. following the current policy.
:param mean: Foo :param mean:
:param std: Bar :param chol:
""" """
if len(log_prob.shape) == 0: if len(log_prob.shape) == 0:
@ -100,7 +100,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
self.values[self.pos] = value.clone().cpu().numpy().flatten() self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy() self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.means[self.pos] = mean.clone().cpu().numpy() self.means[self.pos] = mean.clone().cpu().numpy()
self.stds[self.pos] = std.clone().cpu().numpy() self.chols[self.pos] = chol.clone().cpu().numpy()
self.pos += 1 self.pos += 1
if self.pos == self.buffer_size: if self.pos == self.buffer_size:
self.full = True self.full = True
@ -114,7 +114,8 @@ class GaussianRolloutBuffer(RolloutBuffer):
self.advantages[batch_inds].flatten(), self.advantages[batch_inds].flatten(),
self.returns[batch_inds].flatten(), self.returns[batch_inds].flatten(),
self.means[batch_inds].reshape((len(batch_inds), -1)), self.means[batch_inds].reshape((len(batch_inds), -1)),
self.stds[batch_inds].reshape((len(batch_inds), -1)), self.chols[batch_inds].reshape(
(len(batch_inds),) + self.cov_shape),
) )
return GaussianRolloutBufferSamples(*tuple(map(self.to_torch, data))) return GaussianRolloutBufferSamples(*tuple(map(self.to_torch, data)))
@ -181,7 +182,7 @@ class GaussianRolloutCollectorAuxclass():
obs_tensor = obs_as_tensor(self._last_obs, self.device) obs_tensor = obs_as_tensor(self._last_obs, self.device)
actions, values, log_probs = self.policy(obs_tensor) actions, values, log_probs = self.policy(obs_tensor)
dist = self.policy.get_distribution(obs_tensor).distribution dist = self.policy.get_distribution(obs_tensor).distribution
mean, std = dist.mean, dist.stddev mean, chol = get_mean_and_chol(dist)
actions = actions.cpu().numpy() actions = actions.cpu().numpy()
# Rescale and perform action # Rescale and perform action
@ -223,7 +224,7 @@ class GaussianRolloutCollectorAuxclass():
rewards[idx] += self.gamma * terminal_value rewards[idx] += self.gamma * terminal_value
rollout_buffer.add(self._last_obs, actions, rewards, rollout_buffer.add(self._last_obs, actions, rewards,
self._last_episode_starts, values, log_probs, mean, std) self._last_episode_starts, values, log_probs, mean, chol)
self._last_obs = new_obs self._last_obs = new_obs
self._last_episode_starts = dones self._last_episode_starts = dones

View File

@ -246,7 +246,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
p = pol._get_action_dist_from_latent(latent_pi) p = pol._get_action_dist_from_latent(latent_pi)
p_dist = p.distribution p_dist = p.distribution
q_dist = new_dist_like( q_dist = new_dist_like(
p_dist, rollout_data.means, rollout_data.stds) p_dist, rollout_data.means, rollout_data.chols)
proj_p = self.projection(p_dist, q_dist, self._global_steps) proj_p = self.projection(p_dist, q_dist, self._global_steps)
if isinstance(p_dist, th.distributions.Normal): if isinstance(p_dist, th.distributions.Normal):
# Normal uses a weird mapping from dimensions into batch_shape # Normal uses a weird mapping from dimensions into batch_shape

View File

@ -26,8 +26,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=2_000_000, showRes=Tru
ppo = PPO( ppo = PPO(
MlpPolicy, MlpPolicy,
env, env,
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.DIAG, 'cov_strength': Strength.DIAG, 'parameterization_type': projection=FrobeniusProjectionLayer(),
ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}}, policy_kwargs={'dist_kwargs': {'neural_strength': Strength.FULL, 'cov_strength': Strength.FULL, 'parameterization_type':
ParametrizationType.CHOL, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
verbose=0, verbose=0,
tensorboard_log=root_path+"/logs_tb/" + tensorboard_log=root_path+"/logs_tb/" +
env_name+"/ppo"+(['', '_sde'][use_sde])+"/", env_name+"/ppo"+(['', '_sde'][use_sde])+"/",