Trying to get kl to work
This commit is contained in:
parent
9d7ce73a0b
commit
84d1cda96c
@ -7,6 +7,9 @@ from gym import spaces
|
|||||||
from stable_baselines3.common.buffers import RolloutBuffer
|
from stable_baselines3.common.buffers import RolloutBuffer
|
||||||
from stable_baselines3.common.vec_env import VecNormalize
|
from stable_baselines3.common.vec_env import VecNormalize
|
||||||
|
|
||||||
|
# TRL requires the origina mean and covariance from the policy when the datapoint was created.
|
||||||
|
# GaussianRolloutBuffer extends the RolloutBuffer by these two fields
|
||||||
|
|
||||||
|
|
||||||
class GaussianRolloutBufferSamples(NamedTuple):
|
class GaussianRolloutBufferSamples(NamedTuple):
|
||||||
observations: th.Tensor
|
observations: th.Tensor
|
||||||
@ -29,18 +32,24 @@ class GaussianRolloutBuffer(RolloutBuffer):
|
|||||||
gae_lambda: float = 1,
|
gae_lambda: float = 1,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
n_envs: int = 1,
|
n_envs: int = 1,
|
||||||
|
cov_shape=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(buffer_size, observation_space, action_space,
|
super().__init__(buffer_size, observation_space, action_space,
|
||||||
device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
|
device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
|
||||||
self.means, self.stds = None, None
|
self.means, self.stds = None, None
|
||||||
|
# TODO: Correct shape for full cov matrix
|
||||||
|
# self.action_space.shape + self.action_space.shape
|
||||||
|
|
||||||
|
if cov_shape == None:
|
||||||
|
cov_shape = self.action_space.shape
|
||||||
|
self.cov_shape = cov_shape
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.means = np.zeros(
|
self.means = np.zeros(
|
||||||
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
||||||
self.stds = np.zeros(
|
self.stds = np.zeros(
|
||||||
# (self.buffer_size, self.n_envs) + self.action_space.shape + self.action_space.shape, dtype=np.float32)
|
(self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
|
||||||
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
|
||||||
super().reset()
|
super().reset()
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
|
Loading…
Reference in New Issue
Block a user