Support full covariance in rollout buffer

This commit is contained in:
Dominik Moritz Roth 2024-03-30 14:42:41 +01:00
parent 1321e47b81
commit 4831fbb7e1
2 changed files with 18 additions and 11 deletions

View File

@@ -87,7 +87,9 @@ class BetterRolloutBuffer(RolloutBuffer):
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
full_cov: bool = False,
):
self.full_cov = full_cov
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
@@ -103,6 +105,9 @@ class BetterRolloutBuffer(RolloutBuffer):
self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
if self.full_cov:
self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim, self.action_dim), dtype=np.float32)
else:
self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.generator_ready = False
@@ -156,6 +161,7 @@ class BetterRolloutBuffer(RolloutBuffer):
mean: th.Tensor,
cov_decomp: th.Tensor,
) -> None:
"""
:param obs: Observation
:param action: Action

View File

@@ -120,7 +120,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
if 'dist_kwargs' not in self.policy_kwargs:
self.policy_kwargs['dist_kwargs'] = {}
self.policy_kwargs['dist_kwargs']['n_envs'] = len(self.env.envs)
self.policy_kwargs['dist_kwargs']['n_envs'] = self.env.num_envs if hasattr(self.env, 'num_envs') else 1
self.rollout_buffer_class = None
self.rollout_buffer_kwargs = {}
@@ -135,6 +135,14 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
else:
self.rollout_buffer_class = BetterRolloutBuffer
self.policy = self.policy_class( # pytype:disable=not-instantiable
self.observation_space,
self.action_space,
self.lr_schedule,
use_sde=self.use_sde,
use_pca=self.use_pca,
**self.policy_kwargs # pytype:disable=not-instantiable
)
self.rollout_buffer = self.rollout_buffer_class(
self.n_steps,
self.observation_space,
@@ -143,16 +151,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
gamma=self.gamma,
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
full_cov=(self.use_pca and self.policy.action_dist.is_full()),
**self.rollout_buffer_kwargs,
)
self.policy = self.policy_class( # pytype:disable=not-instantiable
self.observation_space,
self.action_space,
self.lr_schedule,
use_sde=self.use_sde,
use_pca=self.use_pca,
**self.policy_kwargs # pytype:disable=not-instantiable
)
self.policy = self.policy.to(self.device)
def collect_rollouts(
@@ -240,7 +241,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
terminal_value = self.policy.predict_values(terminal_obs)[0]
rewards[idx] += self.gamma * terminal_value
rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale)
rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale if hasattr(distributions.distribution, 'scale') else distributions.distribution.scale_tril)
self._last_obs = new_obs
self._last_episode_starts = dones