* v0.5 (#9) * update idql configs * update awr configs * update dipo configs * update qsm configs * update dqm configs * update project version to 0.5.0
89 lines
2.4 KiB
Python
89 lines
2.4 KiB
Python
"""
Soft Actor Critic (SAC) with Gaussian policy.

"""
|
|
|
|
import torch
|
|
import logging
|
|
from copy import deepcopy
|
|
import torch.nn.functional as F
|
|
|
|
from model.common.gaussian import GaussianModel
|
|
|
|
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class SAC_Gaussian(GaussianModel):
    """Soft Actor-Critic loss computations for a Gaussian policy.

    The actor is passed to the parent ``GaussianModel`` as ``network``. On top
    of that, this class holds a critic and a deep-copied target critic. Both
    are called as ``critic(obs, action)`` and their result is unpacked into two
    Q-value estimates.
    """

    def __init__(self, actor, critic, **kwargs):
        """Set up the policy, the critic, and its target copy.

        Args:
            actor: policy network, forwarded to the parent as ``network``.
            critic: module whose call result unpacks into two Q-value tensors.
            **kwargs: forwarded unchanged to ``GaussianModel.__init__``.
        """
        super().__init__(network=actor, **kwargs)

        # Double critic networks. ``self.device`` is read here, so it is
        # expected to have been set by the parent constructor.
        self.critic = critic.to(self.device)

        # Double target networks: start as an exact copy of the critic.
        self.target_critic = deepcopy(self.critic).to(self.device)

    def loss_critic(
        self,
        obs,
        next_obs,
        actions,
        rewards,
        terminated,
        gamma,
        alpha,
    ):
        """Return the summed MSE loss of both critics against the soft target.

        The target is built without gradient tracking as
        ``rewards + gamma * (min(Q1', Q2') - alpha * logprob') * (1 - terminated)``
        where the primed quantities come from the target critic and from an
        action sampled by the policy at ``next_obs``.

        Args:
            obs: observations fed to the critic together with ``actions``.
            next_obs: observations used to sample the next action.
            actions: actions paired with ``obs``.
            rewards: reward term of the target.
            terminated: termination indicator; the bootstrap term is scaled by
                ``1 - terminated``.
            gamma: discount factor.
            alpha: entropy temperature.

        Returns:
            ``mse(Q1, target) + mse(Q2, target)``.

        NOTE(review): ``rewards``, ``terminated`` and the log-probs must be
        shape-compatible with the critic outputs; ``F.mse_loss`` would
        otherwise broadcast -- confirm against the caller.
        """
        with torch.no_grad():
            sampled_next_actions, sampled_next_logprobs = self.forward(
                cond=next_obs,
                deterministic=False,
                get_logprob=True,
            )
            target_q1, target_q2 = self.target_critic(next_obs, sampled_next_actions)
            # Pessimistic (min) estimate plus the entropy bonus.
            min_target_q = torch.min(target_q1, target_q2)
            soft_next_value = min_target_q - alpha * sampled_next_logprobs

            # Target value: no bootstrapping past a terminated transition.
            not_done = 1 - terminated
            bellman_target = rewards + gamma * soft_next_value * not_done

        # Gradients flow only through the online critic's predictions.
        q1_pred, q2_pred = self.critic(obs, actions)
        q1_loss = F.mse_loss(q1_pred, bellman_target)
        q2_loss = F.mse_loss(q2_pred, bellman_target)
        return q1_loss + q2_loss

    def loss_actor(self, obs, alpha):
        """Return the mean of ``-min(Q1, Q2) + alpha * logprob``.

        The action is sampled with ``reparameterize=True`` so the critic's
        value can be differentiated with respect to the policy.

        Args:
            obs: observations to sample actions for and to evaluate.
            alpha: entropy temperature.
        """
        sampled_action, sampled_logprob = self.forward(
            obs,
            deterministic=False,
            reparameterize=True,
            get_logprob=True,
        )
        q1, q2 = self.critic(obs, sampled_action)
        min_q = torch.min(q1, q2)
        per_sample_loss = -min_q + alpha * sampled_logprob
        return per_sample_loss.mean()

    def loss_temperature(self, obs, alpha, target_entropy):
        """Return ``-mean(alpha * (logprob + target_entropy))``.

        The log-probability is computed without gradient tracking, so the
        policy gets no gradient from this loss.

        Args:
            obs: observations to sample actions for.
            alpha: entropy temperature being optimized.
            target_entropy: desired policy entropy.
        """
        with torch.no_grad():
            _, sampled_logprob = self.forward(
                obs,
                deterministic=False,
                get_logprob=True,
            )
        entropy_gap = sampled_logprob + target_entropy
        return -torch.mean(alpha * entropy_gap)

    def update_target_critic(self, tau):
        """Blend the critic's parameters into the target critic in place.

        Each target parameter becomes ``(1 - tau) * target + tau * source``.

        Args:
            tau: interpolation weight given to the online critic.
        """
        param_pairs = zip(self.target_critic.parameters(), self.critic.parameters())
        for target_param, online_param in param_pairs:
            blended = target_param.data * (1.0 - tau) + online_param.data * tau
            target_param.data.copy_(blended)
|