vlearn loss draft

This commit is contained in:
parent e106d8701f
commit 416c2036a5
fancy_rl/objectives/vlearn_loss.py (new file, 106 lines)
@@ -0,0 +1,106 @@
import torch
from torchrl.objectives import LossModule
from torch.distributions import Normal


class VLEARNLoss(LossModule):
    def __init__(
        self,
        actor_network,
        critic_network,
        old_actor_network,
        gamma=0.99,
        lmbda=0.95,
        entropy_coef=0.01,
        critic_coef=0.5,
        normalize_advantage=True,
        eps=1e-8,
        delta=0.1,
    ):
        super().__init__()
        self.actor_network = actor_network
        self.critic_network = critic_network
        self.old_actor_network = old_actor_network
        self.gamma = gamma
        self.lmbda = lmbda
        self.entropy_coef = entropy_coef
        self.critic_coef = critic_coef
        self.normalize_advantage = normalize_advantage
        self.eps = eps
        self.delta = delta

    def forward(self, tensordict):
        # Compute returns and advantages (targets are detached from the graph)
        with torch.no_grad():
            returns = self.compute_returns(tensordict)
            values = self.critic_network(tensordict)["state_value"]
            advantages = returns - values
            if self.normalize_advantage:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
            # The critic output carries a trailing singleton dim; drop it so the
            # advantages broadcast against the per-sample log-prob ratios below.
            advantages = advantages.squeeze(-1)

        # Compute actor loss
        new_td = self.actor_network(tensordict)
        old_td = self.old_actor_network(tensordict)

        new_dist = Normal(new_td["loc"], new_td["scale"])
        old_dist = Normal(old_td["loc"], old_td["scale"])

        new_log_prob = new_dist.log_prob(tensordict["action"]).sum(-1)
        old_log_prob = old_dist.log_prob(tensordict["action"]).sum(-1)

        ratio = torch.exp(new_log_prob - old_log_prob)

        # Compute projection: where KL(new || old) exceeds the trust-region bound
        # delta, interpolate mean and scale back towards the old policy by alpha.
        kl = torch.distributions.kl.kl_divergence(new_dist, old_dist).sum(-1)
        alpha = torch.where(kl > self.delta,
                            torch.sqrt(self.delta / (kl + self.eps)),
                            torch.ones_like(kl))
        proj_loc = alpha.unsqueeze(-1) * new_td["loc"] + (1 - alpha.unsqueeze(-1)) * old_td["loc"]
        proj_scale = torch.sqrt(
            alpha.unsqueeze(-1) ** 2 * new_td["scale"] ** 2
            + (1 - alpha.unsqueeze(-1)) ** 2 * old_td["scale"] ** 2
        )
        proj_dist = Normal(proj_loc, proj_scale)

        proj_log_prob = proj_dist.log_prob(tensordict["action"]).sum(-1)
        proj_ratio = torch.exp(proj_log_prob - old_log_prob)

        # Take the more pessimistic of the raw and projected surrogate objectives
        policy_loss = -torch.min(
            ratio * advantages,
            proj_ratio * advantages,
        ).mean()

        # Compute critic loss
        value_pred = self.critic_network(tensordict)["state_value"]
        critic_loss = 0.5 * (returns - value_pred).pow(2).mean()

        # Compute entropy loss
        entropy_loss = -self.entropy_coef * new_dist.entropy().mean()

        # Combine losses
        loss = policy_loss + self.critic_coef * critic_loss + entropy_loss

        return {
            "loss": loss,
            "policy_loss": policy_loss,
            "critic_loss": critic_loss,
            "entropy_loss": entropy_loss,
        }

    def compute_returns(self, tensordict):
        # GAE(lambda) returns; expects "reward" and "done" with a trailing dim of 1.
        rewards = tensordict["reward"]
        dones = tensordict["done"].float()
        values = self.critic_network(tensordict)["state_value"]

        advantages = torch.zeros_like(rewards)
        last_gae_lam = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]

            delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
            advantages[t] = last_gae_lam = delta + self.gamma * self.lmbda * (1 - dones[t]) * last_gae_lam

        returns = advantages + values

        return returns
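A minimal usage sketch, not part of this commit: it wires the draft loss to a TorchRL-style actor and critic and runs one backward pass on a fake batch. The network sizes, the batch contents, and the "observation" key are illustrative assumptions; the loss itself only needs an actor that writes "loc"/"scale", a critic that writes "state_value", and "action"/"reward"/"done" entries in the tensordict.

# Usage sketch (illustrative, not part of this commit).
import copy

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, NormalParamExtractor

from torchrl.modules import ValueOperator

obs_dim, act_dim, T = 8, 2, 64

# Actor: writes the "loc"/"scale" keys the loss reads.
policy_net = nn.Sequential(
    nn.Linear(obs_dim, 64), nn.Tanh(),
    nn.Linear(64, 2 * act_dim), NormalParamExtractor(),
)
actor = TensorDictModule(policy_net, in_keys=["observation"], out_keys=["loc", "scale"])
old_actor = copy.deepcopy(actor)  # frozen snapshot of the behaviour policy

# Critic: writes "state_value" (ValueOperator's default output key).
critic = ValueOperator(
    nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, 1)),
    in_keys=["observation"],
)

loss_module = VLEARNLoss(actor, critic, old_actor, delta=0.1)

# Fake flat batch of T transitions with the keys the draft expects.
batch = TensorDict({
    "observation": torch.randn(T, obs_dim),
    "action": torch.randn(T, act_dim),
    "reward": torch.randn(T, 1),
    "done": torch.zeros(T, 1, dtype=torch.bool),
}, batch_size=[T])

losses = loss_module(batch)
losses["loss"].backward()

Refreshing old_actor_network is left to the training loop; copying the actor's weights into it between update epochs, for example via old_actor.load_state_dict(actor.state_dict()), is one straightforward option.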