Support Full Cov in Buffer

Dominik Moritz Roth 2024-03-30 14:42:41 +01:00
parent 1321e47b81
commit 4831fbb7e1
2 changed files with 18 additions and 11 deletions


@@ -87,7 +87,9 @@ class BetterRolloutBuffer(RolloutBuffer):
         gae_lambda: float = 1,
         gamma: float = 0.99,
         n_envs: int = 1,
+        full_cov: bool = False,
     ):
+        self.full_cov = full_cov
         super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
         self.gae_lambda = gae_lambda
         self.gamma = gamma
@@ -103,7 +105,10 @@ class BetterRolloutBuffer(RolloutBuffer):
         self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
-        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+        if self.full_cov:
+            self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim, self.action_dim), dtype=np.float32)
+        else:
+            self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
         self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.generator_ready = False
         super().reset()
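
With full_cov enabled, the buffer reserves one (action_dim, action_dim) slot per step instead of a single scale vector, i.e. room for a full covariance decomposition such as a Cholesky factor. A minimal sketch of how such a stored slice could be turned back into a distribution (illustrative only, not the repository's code; the Cholesky interpretation of cov_decomps is an assumption):

import torch as th

def rebuild_dist(mean: th.Tensor, cov_decomp: th.Tensor, full_cov: bool):
    # mean:       (batch, action_dim)
    # cov_decomp: (batch, action_dim) in the diagonal case,
    #             (batch, action_dim, action_dim) in the full case
    if full_cov:
        # assume cov_decomp is a lower-triangular factor L with cov = L @ L.T
        return th.distributions.MultivariateNormal(mean, scale_tril=cov_decomp)
    # diagonal case: cov_decomp holds per-dimension standard deviations
    return th.distributions.Normal(mean, cov_decomp)

diag = rebuild_dist(th.zeros(2, 4), th.ones(2, 4), full_cov=False)
full = rebuild_dist(th.zeros(2, 4), th.eye(4).expand(2, 4, 4), full_cov=True)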
@@ -156,6 +161,7 @@ class BetterRolloutBuffer(RolloutBuffer):
         mean: th.Tensor,
         cov_decomp: th.Tensor,
     ) -> None:
         """
         :param obs: Observation
         :param action: Action
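
The add() signature already carries the per-step cov_decomp; this commit only changes the shape of the array it lands in. A hedged sketch of the store step (the real add() body is not part of this diff; names are illustrative):

import numpy as np
import torch as th

def store_cov_decomp(cov_decomps: np.ndarray, pos: int, cov_decomp: th.Tensor) -> None:
    # cov_decomps: the preallocated (buffer_size, n_envs, ...) array from reset()
    # cov_decomp:  the (n_envs, ...) tensor for the current step
    cov_decomps[pos] = cov_decomp.detach().cpu().numpy()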


@@ -120,7 +120,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         if 'dist_kwargs' not in self.policy_kwargs:
             self.policy_kwargs['dist_kwargs'] = {}
-        self.policy_kwargs['dist_kwargs']['n_envs'] = len(self.env.envs)
+        self.policy_kwargs['dist_kwargs']['n_envs'] = self.env.num_envs if hasattr(self.env, 'num_envs') else 1
         self.rollout_buffer_class = None
         self.rollout_buffer_kwargs = {}
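
The replaced assignment leans on the Stable-Baselines3 VecEnv API, where a vectorized env exposes num_envs, instead of reaching into the .envs list that only some wrapper implementations provide; the hasattr guard falls back to 1 for a non-vectorized env. A standalone equivalent of that guard (helper name is mine):

def n_envs_of(env) -> int:
    # VecEnv subclasses define .num_envs; a plain gym env does not.
    return env.num_envs if hasattr(env, 'num_envs') else 1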
@@ -135,6 +135,14 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         else:
             self.rollout_buffer_class = BetterRolloutBuffer
+        self.policy = self.policy_class(  # pytype:disable=not-instantiable
+            self.observation_space,
+            self.action_space,
+            self.lr_schedule,
+            use_sde=self.use_sde,
+            use_pca=self.use_pca,
+            **self.policy_kwargs  # pytype:disable=not-instantiable
+        )
         self.rollout_buffer = self.rollout_buffer_class(
             self.n_steps,
             self.observation_space,
@@ -143,16 +151,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             gamma=self.gamma,
             gae_lambda=self.gae_lambda,
             n_envs=self.n_envs,
+            full_cov=(self.use_pca and self.policy.action_dist.is_full()),
             **self.rollout_buffer_kwargs,
         )
-        self.policy = self.policy_class(  # pytype:disable=not-instantiable
-            self.observation_space,
-            self.action_space,
-            self.lr_schedule,
-            use_sde=self.use_sde,
-            use_pca=self.use_pca,
-            **self.policy_kwargs  # pytype:disable=not-instantiable
-        )
         self.policy = self.policy.to(self.device)

     def collect_rollouts(
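
Moving the policy construction above the buffer construction is what makes the new full_cov argument computable: the buffer can only ask self.policy.action_dist.is_full() once the policy exists. A toy illustration of that ordering dependency (ToyPolicy/ToyDist are made up; only the is_full() name comes from the diff):

class ToyDist:
    def __init__(self, full: bool):
        self._full = full

    def is_full(self) -> bool:
        return self._full

class ToyPolicy:
    def __init__(self, use_pca: bool):
        # the policy fixes its distribution family at construction time
        self.action_dist = ToyDist(full=use_pca)

use_pca = True
policy = ToyPolicy(use_pca)          # must be built before the buffer
full_cov = use_pca and policy.action_dist.is_full()
print(full_cov)                      # True -> reserve (action_dim, action_dim) slots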
@@ -240,7 +241,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
                     terminal_value = self.policy.predict_values(terminal_obs)[0]
                     rewards[idx] += self.gamma * terminal_value
-            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale if hasattr(distributions.distribution, 'scale') else distributions.distribution.scale_tril)
             self._last_obs = new_obs
             self._last_episode_starts = dones
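
The guarded attribute access mirrors how torch.distributions names its parameters: a diagonal Normal stores its standard deviations in .scale, while MultivariateNormal exposes a Cholesky factor as .scale_tril and has no .scale. A quick check of that naming (shapes chosen arbitrarily):

import torch as th

diag = th.distributions.Normal(th.zeros(3), th.ones(3))
full = th.distributions.MultivariateNormal(th.zeros(3), scale_tril=th.eye(3))

print(hasattr(diag, 'scale'), tuple(diag.scale.shape))        # True (3,)
print(hasattr(full, 'scale'), tuple(full.scale_tril.shape))   # False (3, 3)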