Support Full Cov in Buffer
This commit is contained in:
parent
1321e47b81
commit
4831fbb7e1
@ -87,7 +87,9 @@ class BetterRolloutBuffer(RolloutBuffer):
|
||||
gae_lambda: float = 1,
|
||||
gamma: float = 0.99,
|
||||
n_envs: int = 1,
|
||||
full_cov: bool = False,
|
||||
):
|
||||
self.full_cov = full_cov
|
||||
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
|
||||
self.gae_lambda = gae_lambda
|
||||
self.gamma = gamma
|
||||
@ -103,7 +105,10 @@ class BetterRolloutBuffer(RolloutBuffer):
|
||||
self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
|
||||
self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
|
||||
if self.full_cov:
|
||||
self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim, self.action_dim), dtype=np.float32)
|
||||
else:
|
||||
self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
|
||||
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
self.generator_ready = False
|
||||
super().reset()
|
||||
@ -156,6 +161,7 @@ class BetterRolloutBuffer(RolloutBuffer):
|
||||
mean: th.Tensor,
|
||||
cov_decomp: th.Tensor,
|
||||
) -> None:
|
||||
|
||||
"""
|
||||
:param obs: Observation
|
||||
:param action: Action
|
||||
|
@ -120,7 +120,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
|
||||
if 'dist_kwargs' not in self.policy_kwargs:
|
||||
self.policy_kwargs['dist_kwargs'] = {}
|
||||
self.policy_kwargs['dist_kwargs']['n_envs'] = len(self.env.envs)
|
||||
self.policy_kwargs['dist_kwargs']['n_envs'] = self.env.num_envs if hasattr(self.env, 'num_envs') else 1
|
||||
|
||||
self.rollout_buffer_class = None
|
||||
self.rollout_buffer_kwargs = {}
|
||||
@ -135,6 +135,14 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
else:
|
||||
self.rollout_buffer_class = BetterRolloutBuffer
|
||||
|
||||
self.policy = self.policy_class( # pytype:disable=not-instantiable
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
self.lr_schedule,
|
||||
use_sde=self.use_sde,
|
||||
use_pca=self.use_pca,
|
||||
**self.policy_kwargs # pytype:disable=not-instantiable
|
||||
)
|
||||
self.rollout_buffer = self.rollout_buffer_class(
|
||||
self.n_steps,
|
||||
self.observation_space,
|
||||
@ -143,16 +151,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
gamma=self.gamma,
|
||||
gae_lambda=self.gae_lambda,
|
||||
n_envs=self.n_envs,
|
||||
full_cov=(self.use_pca and self.policy.action_dist.is_full()),
|
||||
**self.rollout_buffer_kwargs,
|
||||
)
|
||||
self.policy = self.policy_class( # pytype:disable=not-instantiable
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
self.lr_schedule,
|
||||
use_sde=self.use_sde,
|
||||
use_pca=self.use_pca,
|
||||
**self.policy_kwargs # pytype:disable=not-instantiable
|
||||
)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
def collect_rollouts(
|
||||
@ -240,7 +241,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
||||
terminal_value = self.policy.predict_values(terminal_obs)[0]
|
||||
rewards[idx] += self.gamma * terminal_value
|
||||
|
||||
rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale)
|
||||
rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale if hasattr(distributions.distribution, 'scale') else distributions.distribution.scale_tril)
|
||||
self._last_obs = new_obs
|
||||
self._last_episode_starts = dones
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user