Support Full Cov in Buffer

parent 1321e47b81
commit 4831fbb7e1
@@ -87,7 +87,9 @@ class BetterRolloutBuffer(RolloutBuffer):
         gae_lambda: float = 1,
         gamma: float = 0.99,
         n_envs: int = 1,
+        full_cov: bool = False,
     ):
+        self.full_cov = full_cov
         super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
         self.gae_lambda = gae_lambda
         self.gamma = gamma
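With the new keyword argument in place, constructing the buffer directly might look like the sketch below. The spaces, sizes, and device are illustrative assumptions, not taken from this commit, and whether the project uses gym or gymnasium spaces is also an assumption; BetterRolloutBuffer itself comes from this repo.

import numpy as np
from gymnasium import spaces  # assumption: recent stable-baselines3 stack

# from <this repo> import BetterRolloutBuffer

obs_space = spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)
act_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

buffer = BetterRolloutBuffer(
    2048,                  # buffer_size
    obs_space,
    act_space,
    device="cpu",
    gae_lambda=0.95,
    gamma=0.99,
    n_envs=4,
    full_cov=True,         # keep a full (action_dim x action_dim) decomposition per step
)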
@@ -103,7 +105,10 @@ class BetterRolloutBuffer(RolloutBuffer):
         self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
-        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+        if self.full_cov:
+            self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim, self.action_dim), dtype=np.float32)
+        else:
+            self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
         self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.generator_ready = False
         super().reset()
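The only difference between the two branches is the trailing shape of cov_decomps: per step and per env, the diagonal path keeps an (action_dim,) vector of scales, while the full path keeps an (action_dim, action_dim) matrix. A minimal sketch of how each stored decomposition maps back to a covariance matrix, assuming the diagonal entries are standard deviations and the full matrix is a Cholesky-style factor (the scale/scale_tril attributes used further down suggest exactly that):

import numpy as np

action_dim = 3

# Diagonal storage: one scale per action dimension -> covariance is diag(scale**2).
scale = np.array([0.5, 0.2, 1.0], dtype=np.float32)      # shape (action_dim,)
cov_diag = np.diag(scale ** 2)                            # shape (action_dim, action_dim)

# Full storage: a lower-triangular factor L -> covariance is L @ L.T.
L = np.tril(np.random.randn(action_dim, action_dim).astype(np.float32))
np.fill_diagonal(L, np.abs(np.diag(L)) + 1e-3)            # keep the diagonal positive
cov_full = L @ L.T                                        # shape (action_dim, action_dim)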
@@ -156,6 +161,7 @@ class BetterRolloutBuffer(RolloutBuffer):
         mean: th.Tensor,
         cov_decomp: th.Tensor,
     ) -> None:
+
         """
         :param obs: Observation
         :param action: Action
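The body of add() is not part of this hunk, so how the two tensors are written into the buffers is not shown. A hedged sketch of the conversion it would need to perform, modeled on how stable-baselines3 copies values and log-probs into numpy arrays; the names and shapes are assumptions consistent with the reset() hunk above:

import numpy as np
import torch as th

buffer_size, n_envs, action_dim = 4, 2, 3
means = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)
cov_decomps = np.zeros((buffer_size, n_envs, action_dim, action_dim), dtype=np.float32)

pos = 0
mean = th.zeros(n_envs, action_dim)                       # what the policy hands to add()
cov_decomp = th.eye(action_dim).repeat(n_envs, 1, 1)      # full-covariance factor per env

# Detach from the graph, move to CPU, and copy into the preallocated slots.
means[pos] = mean.detach().cpu().numpy()
cov_decomps[pos] = cov_decomp.detach().cpu().numpy()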
@@ -120,7 +120,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):

         if 'dist_kwargs' not in self.policy_kwargs:
             self.policy_kwargs['dist_kwargs'] = {}
-        self.policy_kwargs['dist_kwargs']['n_envs'] = len(self.env.envs)
+        self.policy_kwargs['dist_kwargs']['n_envs'] = self.env.num_envs if hasattr(self.env, 'num_envs') else 1

         self.rollout_buffer_class = None
         self.rollout_buffer_kwargs = {}
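len(self.env.envs) only works for vectorized wrappers that keep an .envs list in the parent process (DummyVecEnv does, SubprocVecEnv does not), while num_envs is defined on the VecEnv base class; the hasattr fallback to 1 covers a non-vectorized env. A small sketch of the difference, assuming a recent stable-baselines3 with gymnasium:

import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

def n_envs_of(env) -> int:
    # Mirrors the new expression in the diff.
    return env.num_envs if hasattr(env, "num_envs") else 1

if __name__ == "__main__":
    dummy = DummyVecEnv([lambda: gym.make("Pendulum-v1") for _ in range(4)])
    print(len(dummy.envs), n_envs_of(dummy))   # 4 4

    # No .envs list of live environments here, but num_envs still reports 4.
    subproc = SubprocVecEnv([lambda: gym.make("Pendulum-v1") for _ in range(4)])
    print(n_envs_of(subproc))                  # 4
    subproc.close()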
@@ -135,6 +135,14 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         else:
             self.rollout_buffer_class = BetterRolloutBuffer

+        self.policy = self.policy_class( # pytype:disable=not-instantiable
+            self.observation_space,
+            self.action_space,
+            self.lr_schedule,
+            use_sde=self.use_sde,
+            use_pca=self.use_pca,
+            **self.policy_kwargs # pytype:disable=not-instantiable
+        )
         self.rollout_buffer = self.rollout_buffer_class(
             self.n_steps,
             self.observation_space,
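Moving the policy construction above the buffer construction is what makes the next hunk possible: the buffer's full_cov flag is derived from self.policy.action_dist.is_full(), so the policy and its action distribution have to exist before the buffer is built. is_full() itself is not defined in this commit; a hypothetical sketch of the kind of predicate it would have to be, keyed on whether the distribution carries a square covariance factor or only per-dimension scales:

import torch as th

# Hypothetical stand-in for this repo's action distribution classes (sketch only).
class SketchActionDist:
    def __init__(self, scale_or_tril: th.Tensor):
        self.scale_or_tril = scale_or_tril

    def is_full(self) -> bool:
        # Square trailing dimensions -> a full covariance decomposition is stored.
        s = self.scale_or_tril.shape
        return len(s) >= 2 and s[-1] == s[-2]

print(SketchActionDist(th.ones(3)).is_full())   # False: per-dimension std only
print(SketchActionDist(th.eye(3)).is_full())    # True: Cholesky-style factor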
@@ -143,16 +151,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             gamma=self.gamma,
             gae_lambda=self.gae_lambda,
             n_envs=self.n_envs,
+            full_cov=(self.use_pca and self.policy.action_dist.is_full()),
             **self.rollout_buffer_kwargs,
         )
-        self.policy = self.policy_class( # pytype:disable=not-instantiable
-            self.observation_space,
-            self.action_space,
-            self.lr_schedule,
-            use_sde=self.use_sde,
-            use_pca=self.use_pca,
-            **self.policy_kwargs # pytype:disable=not-instantiable
-        )
         self.policy = self.policy.to(self.device)

     def collect_rollouts(
@@ -240,7 +241,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
                         terminal_value = self.policy.predict_values(terminal_obs)[0]
                     rewards[idx] += self.gamma * terminal_value

-            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale if hasattr(distributions.distribution, 'scale') else distributions.distribution.scale_tril)
             self._last_obs = new_obs
             self._last_episode_starts = dones

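The hasattr branch in the changed line handles the two torch.distributions types the policy can yield: a diagonal Normal exposes scale (per-dimension standard deviation), while MultivariateNormal exposes scale_tril (the lower-triangular Cholesky factor of the covariance) and has no scale attribute. A small sketch of that attribute difference:

import torch as th
from torch.distributions import MultivariateNormal, Normal

mean = th.zeros(2, 3)  # (n_envs, action_dim)

diag = Normal(mean, th.ones(2, 3))
full = MultivariateNormal(mean, scale_tril=th.eye(3).repeat(2, 1, 1))

print(hasattr(diag, "scale"), hasattr(diag, "scale_tril"))   # True False
print(hasattr(full, "scale"), hasattr(full, "scale_tril"))   # False True

def cov_decomp_of(dist):
    # Mirrors the expression passed to rollout_buffer.add() in this hunk.
    return dist.scale if hasattr(dist, "scale") else dist.scale_tril

print(cov_decomp_of(diag).shape, cov_decomp_of(full).shape)  # torch.Size([2, 3]) torch.Size([2, 3, 3])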