Trying to get kl to work
This commit is contained in:
parent
9d7ce73a0b
commit
84d1cda96c
@ -7,6 +7,9 @@ from gym import spaces
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
||||
# TRL requires the origina mean and covariance from the policy when the datapoint was created.
|
||||
# GaussianRolloutBuffer extends the RolloutBuffer by these two fields
|
||||
|
||||
|
||||
class GaussianRolloutBufferSamples(NamedTuple):
|
||||
observations: th.Tensor
|
||||
@ -29,18 +32,24 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
||||
gae_lambda: float = 1,
|
||||
gamma: float = 0.99,
|
||||
n_envs: int = 1,
|
||||
cov_shape=None,
|
||||
):
|
||||
|
||||
super().__init__(buffer_size, observation_space, action_space,
|
||||
device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
|
||||
self.means, self.stds = None, None
|
||||
# TODO: Correct shape for full cov matrix
|
||||
# self.action_space.shape + self.action_space.shape
|
||||
|
||||
if cov_shape == None:
|
||||
cov_shape = self.action_space.shape
|
||||
self.cov_shape = cov_shape
|
||||
|
||||
def reset(self) -> None:
|
||||
self.means = np.zeros(
|
||||
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
||||
self.stds = np.zeros(
|
||||
# (self.buffer_size, self.n_envs) + self.action_space.shape + self.action_space.shape, dtype=np.float32)
|
||||
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
||||
(self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
|
||||
super().reset()
|
||||
|
||||
def add(
|
||||
|
Loading…
Reference in New Issue
Block a user