Trying to get kl to work
This commit is contained in:
		
							parent
							
								
									9d7ce73a0b
								
							
						
					
					
						commit
						84d1cda96c
					
				@ -7,6 +7,9 @@ from gym import spaces
 | 
				
			|||||||
from stable_baselines3.common.buffers import RolloutBuffer
 | 
					from stable_baselines3.common.buffers import RolloutBuffer
 | 
				
			||||||
from stable_baselines3.common.vec_env import VecNormalize
 | 
					from stable_baselines3.common.vec_env import VecNormalize
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# TRL requires the original mean and covariance from the policy when the datapoint was created.
 | 
				
			||||||
 | 
					# GaussianRolloutBuffer extends the RolloutBuffer by these two fields
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GaussianRolloutBufferSamples(NamedTuple):
 | 
					class GaussianRolloutBufferSamples(NamedTuple):
 | 
				
			||||||
    observations: th.Tensor
 | 
					    observations: th.Tensor
 | 
				
			||||||
@ -29,18 +32,24 @@ class GaussianRolloutBuffer(RolloutBuffer):
 | 
				
			|||||||
        gae_lambda: float = 1,
 | 
					        gae_lambda: float = 1,
 | 
				
			||||||
        gamma: float = 0.99,
 | 
					        gamma: float = 0.99,
 | 
				
			||||||
        n_envs: int = 1,
 | 
					        n_envs: int = 1,
 | 
				
			||||||
 | 
					        cov_shape=None,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        super().__init__(buffer_size, observation_space, action_space,
 | 
					        super().__init__(buffer_size, observation_space, action_space,
 | 
				
			||||||
                         device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
 | 
					                         device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
 | 
				
			||||||
        self.means, self.stds = None, None
 | 
					        self.means, self.stds = None, None
 | 
				
			||||||
 | 
					        # TODO: Correct shape for full cov matrix
 | 
				
			||||||
 | 
					        # self.action_space.shape + self.action_space.shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if cov_shape == None:
 | 
				
			||||||
 | 
					            cov_shape = self.action_space.shape
 | 
				
			||||||
 | 
					        self.cov_shape = cov_shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def reset(self) -> None:
 | 
					    def reset(self) -> None:
 | 
				
			||||||
        self.means = np.zeros(
 | 
					        self.means = np.zeros(
 | 
				
			||||||
            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
 | 
					            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
 | 
				
			||||||
        self.stds = np.zeros(
 | 
					        self.stds = np.zeros(
 | 
				
			||||||
            # (self.buffer_size, self.n_envs) + self.action_space.shape + self.action_space.shape, dtype=np.float32)
 | 
					            (self.buffer_size, self.n_envs) + self.cov_shape, dtype=np.float32)
 | 
				
			||||||
            (self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
 | 
					 | 
				
			||||||
        super().reset()
 | 
					        super().reset()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def add(
 | 
					    def add(
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user