dppo/model/rl/gaussian_sac.py
Allen Z. Ren e0842e71dc
v0.5 to main (#10)
* v0.5 (#9)

* update idql configs

* update awr configs

* update dipo configs

* update qsm configs

* update dqm configs

* update project version to 0.5.0
2024-10-07 16:35:13 -04:00

89 lines
2.4 KiB
Python

"""
Soft Actor Critic (SAC) with Gaussian policy.
"""
import torch
import logging
from copy import deepcopy
import torch.nn.functional as F
from model.common.gaussian import GaussianModel
log = logging.getLogger(__name__)
class SAC_Gaussian(GaussianModel):
def __init__(
self,
actor,
critic,
**kwargs,
):
super().__init__(network=actor, **kwargs)
# initialize doubel critic networks
self.critic = critic.to(self.device)
# initialize double target networks
self.target_critic = deepcopy(self.critic).to(self.device)
def loss_critic(
self,
obs,
next_obs,
actions,
rewards,
terminated,
gamma,
alpha,
):
with torch.no_grad():
next_actions, next_logprobs = self.forward(
cond=next_obs,
deterministic=False,
get_logprob=True,
)
next_q1, next_q2 = self.target_critic(
next_obs,
next_actions,
)
next_q = torch.min(next_q1, next_q2) - alpha * next_logprobs
# target value
target_q = rewards + gamma * next_q * (1 - terminated)
current_q1, current_q2 = self.critic(obs, actions)
loss_critic = F.mse_loss(current_q1, target_q) + F.mse_loss(
current_q2, target_q
)
return loss_critic
def loss_actor(self, obs, alpha):
action, logprob = self.forward(
obs,
deterministic=False,
reparameterize=True,
get_logprob=True,
)
current_q1, current_q2 = self.critic(obs, action)
loss_actor = -torch.min(current_q1, current_q2) + alpha * logprob
return loss_actor.mean()
def loss_temperature(self, obs, alpha, target_entropy):
with torch.no_grad():
_, logprob = self.forward(
obs,
deterministic=False,
get_logprob=True,
)
loss_alpha = -torch.mean(alpha * (logprob + target_entropy))
return loss_alpha
def update_target_critic(self, tau):
for target_param, source_param in zip(
self.target_critic.parameters(), self.critic.parameters()
):
target_param.data.copy_(
target_param.data * (1.0 - tau) + source_param.data * tau
)