Smashing bugs (dimension mismatch between Normal and
Independent/MultivariateNormal)
This commit is contained in:
parent
ab557a8856
commit
a86d19053d
@ -360,13 +360,15 @@ class CholNet(nn.Module):
|
||||
raise Exception()
|
||||
|
||||
def _chol_from_flat(self, flat_chol):
|
||||
chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
|
||||
# chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
|
||||
chol = fill_triangular(flat_chol)
|
||||
return self._ensure_diagonal_positive(chol)
|
||||
|
||||
def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
|
||||
pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
|
||||
sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
|
||||
self._flat_chol_len, -1, -1)
|
||||
# sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
|
||||
# self._flat_chol_len, -1, -1)
|
||||
sphe_chol = fill_triangular(pos_flat_sphe_chol)
|
||||
chol = self._chol_from_sphe_chol(sphe_chol)
|
||||
return chol
|
||||
|
||||
|
@ -26,7 +26,7 @@ class GaussianRolloutBufferSamples(NamedTuple):
|
||||
advantages: th.Tensor
|
||||
returns: th.Tensor
|
||||
means: th.Tensor
|
||||
stds: th.Tensor
|
||||
chols: th.Tensor
|
||||
|
||||
|
||||
class GaussianRolloutBuffer(RolloutBuffer):
|
||||
@ -56,7 +56,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
||||
def reset(self) -> None:
|
||||
self.means = np.zeros(
|
||||
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
||||
self.stds = np.zeros(
|
||||
self.chols = np.zeros(
|
||||
(self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
|
||||
super().reset()
|
||||
|
||||
@ -69,7 +69,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
||||
value: th.Tensor,
|
||||
log_prob: th.Tensor,
|
||||
mean: th.Tensor,
|
||||
std: th.Tensor,
|
||||
chol: th.Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
:param obs: Observation
|
||||
@ -80,8 +80,8 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
||||
following the current policy.
|
||||
:param log_prob: log probability of the action
|
||||
following the current policy.
|
||||
:param mean: Foo
|
||||
:param std: Bar
|
||||
:param mean: mean of the policy's action distribution for this step
|
||||
:param chol: Cholesky factor of the action distribution's covariance (stored with shape cov_shape)
|
||||
"""
|
||||
|
||||
if len(log_prob.shape) == 0:
|
||||
@ -100,7 +100,7 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
||||
self.values[self.pos] = value.clone().cpu().numpy().flatten()
|
||||
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
|
||||
self.means[self.pos] = mean.clone().cpu().numpy()
|
||||
self.stds[self.pos] = std.clone().cpu().numpy()
|
||||
self.chols[self.pos] = chol.clone().cpu().numpy()
|
||||
self.pos += 1
|
||||
if self.pos == self.buffer_size:
|
||||
self.full = True
|
||||
@ -114,7 +114,8 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
||||
self.advantages[batch_inds].flatten(),
|
||||
self.returns[batch_inds].flatten(),
|
||||
self.means[batch_inds].reshape((len(batch_inds), -1)),
|
||||
self.stds[batch_inds].reshape((len(batch_inds), -1)),
|
||||
self.chols[batch_inds].reshape(
|
||||
(len(batch_inds),) + self.cov_shape),
|
||||
)
|
||||
return GaussianRolloutBufferSamples(*tuple(map(self.to_torch, data)))
|
||||
|
||||
@ -181,7 +182,7 @@ class GaussianRolloutCollectorAuxclass():
|
||||
obs_tensor = obs_as_tensor(self._last_obs, self.device)
|
||||
actions, values, log_probs = self.policy(obs_tensor)
|
||||
dist = self.policy.get_distribution(obs_tensor).distribution
|
||||
mean, std = dist.mean, dist.stddev
|
||||
mean, chol = get_mean_and_chol(dist)
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
# Rescale and perform action
|
||||
@ -223,7 +224,7 @@ class GaussianRolloutCollectorAuxclass():
|
||||
rewards[idx] += self.gamma * terminal_value
|
||||
|
||||
rollout_buffer.add(self._last_obs, actions, rewards,
|
||||
self._last_episode_starts, values, log_probs, mean, std)
|
||||
self._last_episode_starts, values, log_probs, mean, chol)
|
||||
self._last_obs = new_obs
|
||||
self._last_episode_starts = dones
|
||||
|
||||
|
@ -246,7 +246,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
p = pol._get_action_dist_from_latent(latent_pi)
|
||||
p_dist = p.distribution
|
||||
q_dist = new_dist_like(
|
||||
p_dist, rollout_data.means, rollout_data.stds)
|
||||
p_dist, rollout_data.means, rollout_data.chols)
|
||||
proj_p = self.projection(p_dist, q_dist, self._global_steps)
|
||||
if isinstance(p_dist, th.distributions.Normal):
|
||||
# Normal has an empty event_shape: the action dimensions end up in batch_shape
|
||||
|
5
test.py
5
test.py
@ -26,8 +26,9 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=2_000_000, showRes=Tru
|
||||
ppo = PPO(
|
||||
MlpPolicy,
|
||||
env,
|
||||
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.DIAG, 'cov_strength': Strength.DIAG, 'parameterization_type':
|
||||
ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
|
||||
projection=FrobeniusProjectionLayer(),
|
||||
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.FULL, 'cov_strength': Strength.FULL, 'parameterization_type':
|
||||
ParametrizationType.CHOL, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
|
||||
verbose=0,
|
||||
tensorboard_log=root_path+"/logs_tb/" +
|
||||
env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
|
||||
|
Loading…
Reference in New Issue
Block a user