Added tmp (untested) vlearn objective impl
This commit is contained in:
parent
cb48badcff
commit
25988bab54
106
fancy_rl/objectives/vlearn.py
Normal file
@@ -0,0 +1,106 @@
import torch
from torchrl.objectives import LossModule
from torch.distributions import Normal


class VLEARNLoss(LossModule):
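    """Trust-region-projected policy-gradient objective (temporary, untested draft).

    Combines a pessimistic importance-sampling surrogate (the minimum over the raw
    and KL-projected policy ratios), a squared-error value loss, and an entropy
    bonus, with advantages computed via GAE in ``compute_returns``.
    """
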
    def __init__(
        self,
        actor_network,
        critic_network,
        old_actor_network,
        gamma=0.99,                 # discount factor
        lmbda=0.95,                 # GAE lambda
        entropy_coef=0.01,          # weight of the entropy bonus
        critic_coef=0.5,            # weight of the critic (value) loss
        normalize_advantage=True,   # standardize advantages per batch
        eps=1e-8,                   # numerical stability constant
        delta=0.1,                  # trust-region bound on the per-sample KL
    ):
        super().__init__()
        self.actor_network = actor_network
        self.critic_network = critic_network
        self.old_actor_network = old_actor_network
        self.gamma = gamma
        self.lmbda = lmbda
        self.entropy_coef = entropy_coef
        self.critic_coef = critic_coef
        self.normalize_advantage = normalize_advantage
        self.eps = eps
        self.delta = delta

    def forward(self, tensordict):
        # Compute returns and advantages (no gradient through the targets)
        with torch.no_grad():
            returns = self.compute_returns(tensordict)
            values = self.critic_network(tensordict)["state_value"]
            advantages = returns - values
            if self.normalize_advantage:
                advantages = (advantages - advantages.mean()) / (advantages.std() + self.eps)

        # Compute actor loss
        new_td = self.actor_network(tensordict)
        old_td = self.old_actor_network(tensordict)

        new_dist = Normal(new_td["loc"], new_td["scale"])
        old_dist = Normal(old_td["loc"], old_td["scale"])

        new_log_prob = new_dist.log_prob(tensordict["action"]).sum(-1)
        old_log_prob = old_dist.log_prob(tensordict["action"]).sum(-1)

        ratio = torch.exp(new_log_prob - old_log_prob)
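
        # Note: torch.distributions.kl_divergence(p, q) computes KL(p || q),
        # so `kl` below measures how far the new policy has moved from the old one.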
        # Compute projection
        kl = torch.distributions.kl.kl_divergence(new_dist, old_dist).sum(-1)
        alpha = torch.where(
            kl > self.delta,
            torch.sqrt(self.delta / (kl + self.eps)),
            torch.ones_like(kl),
        )
        proj_loc = alpha.unsqueeze(-1) * new_td["loc"] + (1 - alpha.unsqueeze(-1)) * old_td["loc"]
        proj_scale = torch.sqrt(
            alpha.unsqueeze(-1) ** 2 * new_td["scale"] ** 2
            + (1 - alpha.unsqueeze(-1)) ** 2 * old_td["scale"] ** 2
        )
        proj_dist = Normal(proj_loc, proj_scale)

        proj_log_prob = proj_dist.log_prob(tensordict["action"]).sum(-1)
        proj_ratio = torch.exp(proj_log_prob - old_log_prob)
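
        # Pessimistic surrogate: take the per-sample minimum of the raw and
        # projected importance-weighted advantages, then negate for gradient descent.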
        policy_loss = -torch.min(
            ratio * advantages,
            proj_ratio * advantages,
        ).mean()

        # Compute critic loss
        value_pred = self.critic_network(tensordict)["state_value"]
        critic_loss = 0.5 * (returns - value_pred).pow(2).mean()

        # Compute entropy loss
        entropy_loss = -self.entropy_coef * new_dist.entropy().mean()

        # Combine losses
        loss = policy_loss + self.critic_coef * critic_loss + entropy_loss

        return {
            "loss": loss,
            "policy_loss": policy_loss,
            "critic_loss": critic_loss,
            "entropy_loss": entropy_loss,
        }
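
    # Generalized Advantage Estimation (GAE): accumulate discounted TD residuals
    # backwards along the leading (time) dimension; return targets are the
    # advantages plus the value estimates.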
    def compute_returns(self, tensordict):
        rewards = tensordict["reward"]
        # Cast to float so `(1 - dones[t])` works even when `done` is a bool tensor
        dones = tensordict["done"].float()
        values = self.critic_network(tensordict)["state_value"]

        advantages = torch.zeros_like(rewards)
        last_gae_lam = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]

            delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
            advantages[t] = last_gae_lam = delta + self.gamma * self.lmbda * (1 - dones[t]) * last_gae_lam

        returns = advantages + values

        return returns
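
The commit does not include a usage example; the sketch below shows one way the loss might be driven, under the assumption that the actor and old-actor networks are TensorDictModules writing "loc" and "scale", the critic writes "state_value", and the batch tensordict carries "observation", "action", "reward", and "done" along a single time-like leading dimension. All module names, key names, and dimensions here are illustrative, not part of the commit.

# Hypothetical usage sketch (not part of the committed file).
import copy

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

from fancy_rl.objectives.vlearn import VLEARNLoss

obs_dim, act_dim, steps = 4, 2, 32


class GaussianHead(nn.Module):
    # Maps observations to the Gaussian mean; uses a state-independent scale.
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.mu = nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, act_dim))
        self.log_scale = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs):
        loc = self.mu(obs)
        return loc, self.log_scale.exp().expand_as(loc)


class ValueHead(nn.Module):
    # Scalar state value per sample, squeezed so "state_value" matches the reward shape.
    def __init__(self, obs_dim):
        super().__init__()
        self.v = nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, 1))

    def forward(self, obs):
        return self.v(obs).squeeze(-1)


actor = TensorDictModule(GaussianHead(obs_dim, act_dim), in_keys=["observation"], out_keys=["loc", "scale"])
old_actor = copy.deepcopy(actor)  # frozen snapshot standing in for the old policy
critic = TensorDictModule(ValueHead(obs_dim), in_keys=["observation"], out_keys=["state_value"])

loss_fn = VLEARNLoss(actor, critic, old_actor)

batch = TensorDict(
    {
        "observation": torch.randn(steps, obs_dim),
        "action": torch.randn(steps, act_dim),
        "reward": torch.randn(steps),
        "done": torch.zeros(steps, dtype=torch.bool),
    },
    batch_size=[steps],
)

out = loss_fn(batch)
out["loss"].backward()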