Trying to get kl to work
This commit is contained in:
		
							parent
							
								
									9d7ce73a0b
								
							
						
					
					
						commit
						84d1cda96c
					
				@ -7,6 +7,9 @@ from gym import spaces
 | 
			
		||||
from stable_baselines3.common.buffers import RolloutBuffer
 | 
			
		||||
from stable_baselines3.common.vec_env import VecNormalize
 | 
			
		||||
 | 
			
		||||
# TRL requires the original mean and covariance from the policy at the time the datapoint was created.
 | 
			
		||||
# GaussianRolloutBuffer extends the RolloutBuffer with these two fields (means and stds).
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GaussianRolloutBufferSamples(NamedTuple):
 | 
			
		||||
    observations: th.Tensor
 | 
			
		||||
@ -29,18 +32,24 @@ class GaussianRolloutBuffer(RolloutBuffer):
 | 
			
		||||
        gae_lambda: float = 1,
 | 
			
		||||
        gamma: float = 0.99,
 | 
			
		||||
        n_envs: int = 1,
 | 
			
		||||
        cov_shape=None,
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        super().__init__(buffer_size, observation_space, action_space,
 | 
			
		||||
                         device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
 | 
			
		||||
        self.means, self.stds = None, None
 | 
			
		||||
        # TODO: Correct shape for full cov matrix
 | 
			
		||||
        # self.action_space.shape + self.action_space.shape
 | 
			
		||||
 | 
			
		||||
        if cov_shape == None:
 | 
			
		||||
            cov_shape = self.action_space.shape
 | 
			
		||||
        self.cov_shape = cov_shape
 | 
			
		||||
 | 
			
		||||
    def reset(self) -> None:
 | 
			
		||||
        self.means = np.zeros(
 | 
			
		||||
            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
 | 
			
		||||
        self.stds = np.zeros(
 | 
			
		||||
            # (self.buffer_size, self.n_envs) + self.action_space.shape + self.action_space.shape, dtype=np.float32)
 | 
			
		||||
            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
 | 
			
		||||
            (self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
 | 
			
		||||
        super().reset()
 | 
			
		||||
 | 
			
		||||
    def add(
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user