Trying to get kl to work

This commit is contained in:
Dominik Moritz Roth 2022-07-01 13:45:58 +02:00
parent 9d7ce73a0b
commit 84d1cda96c

View File

@ -7,6 +7,9 @@ from gym import spaces
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize
# TRL requires the origina mean and covariance from the policy when the datapoint was created.
# GaussianRolloutBuffer extends the RolloutBuffer by these two fields
class GaussianRolloutBufferSamples(NamedTuple):
observations: th.Tensor
@ -29,18 +32,24 @@ class GaussianRolloutBuffer(RolloutBuffer):
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
cov_shape=None,
):
super().__init__(buffer_size, observation_space, action_space,
device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
self.means, self.stds = None, None
# TODO: Correct shape for full cov matrix
# self.action_space.shape + self.action_space.shape
if cov_shape == None:
cov_shape = self.action_space.shape
self.cov_shape = cov_shape
def reset(self) -> None:
self.means = np.zeros(
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
self.stds = np.zeros(
# (self.buffer_size, self.n_envs) + self.action_space.shape + self.action_space.shape, dtype=np.float32)
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
(self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
super().reset()
def add(