diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py
index a67014e..421926f 100644
--- a/metastable_baselines/distributions/distributions.py
+++ b/metastable_baselines/distributions/distributions.py
@@ -41,8 +41,8 @@ class ParametrizationType(Enum):
     CHOL = 1
     SPHERICAL_CHOL = 2
     # Not (yet?) implemented:
-    #GIVENS = 3
-    #NNLN_EIGEN = 4
+    # GIVENS = 3
+    # NNLN_EIGEN = 4
 
 
 class EnforcePositiveType(Enum):
@@ -360,13 +360,15 @@ class CholNet(nn.Module):
         raise Exception()
 
     def _chol_from_flat(self, flat_chol):
-        chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
+        # chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
+        chol = fill_triangular(flat_chol)
         return self._ensure_diagonal_positive(chol)
 
     def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
         pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
-        sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
-            self._flat_chol_len, -1, -1)
+        # sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
+        #     self._flat_chol_len, -1, -1)
+        sphe_chol = fill_triangular(pos_flat_sphe_chol)
         chol = self._chol_from_sphe_chol(sphe_chol)
         return chol
 
diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py
index 8bec087..ddb65f9 100644
--- a/metastable_baselines/misc/rollout_buffer.py
+++ b/metastable_baselines/misc/rollout_buffer.py
@@ -26,7 +26,7 @@ class GaussianRolloutBufferSamples(NamedTuple):
     advantages: th.Tensor
     returns: th.Tensor
     means: th.Tensor
-    stds: th.Tensor
+    chols: th.Tensor
 
 
 class GaussianRolloutBuffer(RolloutBuffer):
@@ -56,7 +56,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
     def reset(self) -> None:
         self.means = np.zeros(
             (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
-        self.stds = np.zeros(
+        self.chols = np.zeros(
             (self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
         super().reset()
 
@@ -69,7 +69,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
         value: th.Tensor,
         log_prob: th.Tensor,
         mean: th.Tensor,
-        std: th.Tensor,
+        chol: th.Tensor,
     ) -> None:
         """
         :param obs: Observation
@@ -80,8 +80,8 @@ class GaussianRolloutBuffer(RolloutBuffer):
             following the current policy.
         :param log_prob: log probability of the action
             following the current policy.
-        :param mean: Foo
-        :param std: Bar
+        :param mean: Mean of the action distribution
+        :param chol: Cholesky factor of the covariance matrix of the action distribution
         """
 
         if len(log_prob.shape) == 0:
@@ -100,7 +100,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
         self.values[self.pos] = value.clone().cpu().numpy().flatten()
         self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
         self.means[self.pos] = mean.clone().cpu().numpy()
-        self.stds[self.pos] = std.clone().cpu().numpy()
+        self.chols[self.pos] = chol.clone().cpu().numpy()
         self.pos += 1
         if self.pos == self.buffer_size:
             self.full = True
@@ -114,7 +114,8 @@ class GaussianRolloutBuffer(RolloutBuffer):
             self.advantages[batch_inds].flatten(),
             self.returns[batch_inds].flatten(),
             self.means[batch_inds].reshape((len(batch_inds), -1)),
-            self.stds[batch_inds].reshape((len(batch_inds), -1)),
+            self.chols[batch_inds].reshape(
+                (len(batch_inds),) + self.cov_shape),
         )
         return GaussianRolloutBufferSamples(*tuple(map(self.to_torch, data)))
 
@@ -181,7 +182,7 @@ class GaussianRolloutCollectorAuxclass():
                 obs_tensor = obs_as_tensor(self._last_obs, self.device)
                 actions, values, log_probs = self.policy(obs_tensor)
                 dist = self.policy.get_distribution(obs_tensor).distribution
-                mean, std = dist.mean, dist.stddev
+                mean, chol = get_mean_and_chol(dist)
             actions = actions.cpu().numpy()
 
             # Rescale and perform action
@@ -223,7 +224,7 @@ class GaussianRolloutCollectorAuxclass():
                     rewards[idx] += self.gamma * terminal_value
 
             rollout_buffer.add(self._last_obs, actions, rewards,
-                               self._last_episode_starts, values, log_probs, mean, std)
+                               self._last_episode_starts, values, log_probs, mean, chol)
             self._last_obs = new_obs
             self._last_episode_starts = dones
 
diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py
index dd86154..a5eacfd 100644
--- a/metastable_baselines/ppo/ppo.py
+++ b/metastable_baselines/ppo/ppo.py
@@ -246,7 +246,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
                 p = pol._get_action_dist_from_latent(latent_pi)
                 p_dist = p.distribution
                 q_dist = new_dist_like(
-                    p_dist, rollout_data.means, rollout_data.stds)
+                    p_dist, rollout_data.means, rollout_data.chols)
                 proj_p = self.projection(p_dist, q_dist, self._global_steps)
                 if isinstance(p_dist, th.distributions.Normal):
                     # Normal uses a weird mapping from dimensions into batch_shape
diff --git a/test.py b/test.py
index 65b5b71..27cbdb5 100755
--- a/test.py
+++ b/test.py
@@ -26,8 +26,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=2_000_000, showRes=Tru
     ppo = PPO(
         MlpPolicy,
         env,
-        policy_kwargs={'dist_kwargs': {'neural_strength': Strength.DIAG, 'cov_strength': Strength.DIAG, 'parameterization_type':
-                       ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
+        projection=FrobeniusProjectionLayer(),
+        policy_kwargs={'dist_kwargs': {'neural_strength': Strength.FULL, 'cov_strength': Strength.FULL, 'parameterization_type':
+                       ParametrizationType.CHOL, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
         verbose=0,
         tensorboard_log=root_path+"/logs_tb/" + env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
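
Reviewer note: the changes above are coupled. Dropping the stray `.expand(self._flat_chol_len, -1, -1)` keeps the batched `(batch, n, n)` output of `fill_triangular` intact, and the `stds` → `chols` rename stores that full Cholesky factor (shape `cov_shape`) in the buffer so `new_dist_like` can rebuild the behavior policy with a full covariance. Below is a minimal shape-level sketch of this; `fill_triangular_sketch` is a hypothetical stand-in for the repo's `fill_triangular` (whose exact fill order may differ), and plain `torch.distributions` is used in place of the repo's helpers.

```python
import torch as th

def fill_triangular_sketch(flat: th.Tensor, n: int) -> th.Tensor:
    # Map (..., n*(n+1)/2) flat vectors to (..., n, n) lower-triangular
    # matrices. Hypothetical stand-in for the repo's fill_triangular;
    # only the shapes matter for this illustration, not the fill order.
    rows, cols = th.tril_indices(n, n)
    out = flat.new_zeros(flat.shape[:-1] + (n, n))
    out[..., rows, cols] = flat
    return out

n, batch = 3, 5
flat = th.randn(batch, n * (n + 1) // 2)   # one flat Cholesky per sample

chol = fill_triangular_sketch(flat, n)
print(chol.shape)  # torch.Size([5, 3, 3]) -- already batched

# The removed `.expand(self._flat_chol_len, -1, -1)` tried to bolt a
# leading dimension of n*(n+1)/2 onto this tensor, clashing with the
# (batch, n, n) layout that _ensure_diagonal_positive and the rollout
# buffer's cov_shape expect.

# Make the diagonal strictly positive so chol is a valid Cholesky factor
# (the repo does this via _ensure_diagonal_positive; softplus is one choice).
diag = th.nn.functional.softplus(chol.diagonal(dim1=-2, dim2=-1))
chol = chol.tril(-1) + th.diag_embed(diag)

# Rebuild the old (behavior) policy q from the stored (mean, chol) pair,
# analogous to new_dist_like(p_dist, rollout_data.means, rollout_data.chols):
mean = th.zeros(batch, n)
q = th.distributions.MultivariateNormal(mean, scale_tril=chol)
print(q.sample().shape)  # torch.Size([5, 3])
```

Under these assumptions, `MultivariateNormal(..., scale_tril=chol)` plays the role that `new_dist_like` fills for the full-covariance (`Strength.FULL` / `ParametrizationType.CHOL`) configuration enabled in `test.py`.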