"""
|
|
Calibrated Conservative Q-Learning (CalQL) for Gaussian policy.
|
|
|
|
"""
|
|
|
|
import torch
import torch.nn as nn
import logging
from copy import deepcopy
import numpy as np
import einops

from model.common.gaussian import GaussianModel

log = logging.getLogger(__name__)

class CalQL_Gaussian(GaussianModel):

    def __init__(
        self,
        actor,
        critic,
        network_path=None,
        cql_clip_diff_min=-np.inf,
        cql_clip_diff_max=np.inf,
        cql_min_q_weight=5.0,
        cql_n_actions=10,
        **kwargs,
    ):
        # Do not load the checkpoint in the parent class; it is loaded below so
        # that the pre-trained critic weights are restored as well.
        super().__init__(network=actor, network_path=None, **kwargs)
        self.cql_clip_diff_min = cql_clip_diff_min
        self.cql_clip_diff_max = cql_clip_diff_max
        self.cql_min_q_weight = cql_min_q_weight
        self.cql_n_actions = cql_n_actions

        # Initialize critic networks
        self.critic = critic.to(self.device)
        self.target_critic = deepcopy(critic).to(self.device)

        # Load pre-trained checkpoint - note the pre-trained critic is also loaded here
        if network_path is not None:
            checkpoint = torch.load(
                network_path,
                map_location=self.device,
                weights_only=True,
            )
            self.load_state_dict(
                checkpoint["model"],
                strict=True,
            )
            log.info("Loaded actor from %s", network_path)
        log.info(
            f"Number of network parameters: {sum(p.numel() for p in self.parameters())}"
        )

    def loss_critic(
        self,
        obs,
        next_obs,
        actions,
        random_actions,
        rewards,
        returns,
        terminated,
        gamma,
    ):
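        """
        Critic loss: TD error on the dataset transitions plus the CalQL
        conservative penalty.

        The TD target backs up the best of `cql_n_actions` sampled next
        actions. The conservative penalty is the logsumexp of Q values of
        random, current-policy, and next-policy actions (importance-weighted)
        minus the Q values of the dataset actions; the policy Q values are
        first lower-bounded by the Monte-Carlo returns for calibration.
        """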
        B = len(actions)

        # Get initial TD loss
        q_data1, q_data2 = self.critic(obs, actions)
        with torch.no_grad():
            # Repeat obs for the sampled next actions
            next_obs_repeated = {
                "state": next_obs["state"].repeat_interleave(
                    self.cql_n_actions, dim=0
                )
            }

            # Get the next actions and logprobs
            next_actions, next_logprobs = self.forward(
                next_obs_repeated,
                deterministic=False,
                get_logprob=True,
            )
            next_q1, next_q2 = self.target_critic(next_obs_repeated, next_actions)
            next_q = torch.min(next_q1, next_q2)

            # Reshape to match the number of samples
            next_q = next_q.view(B, self.cql_n_actions)  # (B, n_sample)
            next_logprobs = next_logprobs.view(B, self.cql_n_actions)  # (B, n_sample)

            # Take the best of the sampled next actions and the corresponding logprobs
            max_idx = torch.argmax(next_q, dim=1)
            next_q = next_q[torch.arange(B), max_idx]
            next_logprobs = next_logprobs[torch.arange(B), max_idx]

            # Get the target Q values
            target_q = rewards + gamma * (1 - terminated) * next_q

        # TD loss
        td_loss_1 = nn.functional.mse_loss(q_data1, target_q)
        td_loss_2 = nn.functional.mse_loss(q_data2, target_q)

        # Correction term for the uniformly sampled random actions
        # (density 0.5 per dimension on [-1, 1] over the action chunk)
        log_rand_pi = 0.5 ** torch.prod(torch.tensor(random_actions.shape[-2:]))

        # Get actions and logprobs from the current policy (no gradient through the actor)
        pi_actions, log_pi = self.forward(
            obs,
            deterministic=False,
            reparameterize=False,
            get_logprob=True,
        )
        pi_next_actions, log_pi_next = self.forward(
            next_obs,
            deterministic=False,
            reparameterize=False,
            get_logprob=True,
        )

        # Random action Q values
        n_random_actions = random_actions.shape[1]
        obs_sample_state = {
            "state": obs["state"].repeat_interleave(n_random_actions, dim=0)
        }
        random_actions = einops.rearrange(random_actions, "B N H A -> (B N) H A")

        # Get the random action Q-values, corrected by the uniform proposal term
        q_rand_1, q_rand_2 = self.critic(obs_sample_state, random_actions)
        q_rand_1 = q_rand_1 - log_rand_pi
        q_rand_2 = q_rand_2 - log_rand_pi

        # Reshape the random action Q values to match the number of samples
        q_rand_1 = q_rand_1.view(B, n_random_actions)  # (B, n_random_actions)
        q_rand_2 = q_rand_2.view(B, n_random_actions)

        # Policy action Q values
        q_pi_1, q_pi_2 = self.critic(obs, pi_actions)
        q_pi_next_1, q_pi_next_2 = self.critic(next_obs, pi_next_actions)

        # Calibration: lower-bound the policy Q values by the Monte-Carlo returns
        q_pi_1 = torch.max(q_pi_1, returns)[:, None]  # (B, 1)
        q_pi_2 = torch.max(q_pi_2, returns)[:, None]  # (B, 1)
        q_pi_next_1 = torch.max(q_pi_next_1, returns)[:, None]  # (B, 1)
        q_pi_next_2 = torch.max(q_pi_next_2, returns)[:, None]  # (B, 1)

        # CQL importance sampling
        q_pi_1 = q_pi_1 - log_pi
        q_pi_2 = q_pi_2 - log_pi
        q_pi_next_1 = q_pi_next_1 - log_pi_next
        q_pi_next_2 = q_pi_next_2 - log_pi_next
        cat_q_1 = torch.cat([q_rand_1, q_pi_1, q_pi_next_1], dim=-1)  # (B, n_random_actions + 2)
        cql_qf1_ood = torch.logsumexp(cat_q_1, dim=-1)  # logsumexp over samples
        cat_q_2 = torch.cat([q_rand_2, q_pi_2, q_pi_next_2], dim=-1)  # (B, n_random_actions + 2)
        cql_qf2_ood = torch.logsumexp(cat_q_2, dim=-1)  # logsumexp over samples

        # Skip the Lagrange version since the paper does not use cql_target_action_gap for kitchen

        # Subtract the Q values of the dataset actions
        cql_qf1_diff = torch.clamp(
            cql_qf1_ood - q_data1,
            min=self.cql_clip_diff_min,
            max=self.cql_clip_diff_max,
        ).mean()
        cql_qf2_diff = torch.clamp(
            cql_qf2_ood - q_data2,
            min=self.cql_clip_diff_min,
            max=self.cql_clip_diff_max,
        ).mean()
        cql_min_qf1_loss = cql_qf1_diff * self.cql_min_q_weight
        cql_min_qf2_loss = cql_qf2_diff * self.cql_min_q_weight

        # Sum the TD and conservative losses
        critic_loss = td_loss_1 + td_loss_2 + cql_min_qf1_loss + cql_min_qf2_loss
        return critic_loss

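    # The remaining losses follow the standard SAC recipe: the actor maximizes
    # the minimum of the two critic estimates with an entropy bonus, and the
    # temperature alpha is tuned toward a target entropy.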
    def loss_actor(self, obs, alpha):
        action, logprob = self.forward(
            obs,
            deterministic=False,
            reparameterize=True,
            get_logprob=True,
        )
        q1, q2 = self.critic(obs, action)
        actor_loss = -torch.min(q1, q2) + alpha * logprob
        return actor_loss.mean()

    def loss_temperature(self, obs, alpha, target_entropy):
        with torch.no_grad():
            _, logprob = self.forward(
                obs,
                deterministic=False,
                get_logprob=True,
            )
        loss_alpha = -torch.mean(alpha * (logprob + target_entropy))
        return loss_alpha

    def update_target_critic(self, tau):
        for target_param, param in zip(
            self.target_critic.parameters(), self.critic.parameters()
        ):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
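
# Illustrative sketch (not part of the original module): a minimal, self-contained
# example of the conservative-penalty bookkeeping used in `loss_critic`, with
# random tensors standing in for critic outputs. All shapes and values below are
# assumptions chosen purely for illustration.
if __name__ == "__main__":
    B, n_rand = 4, 10
    q_data = torch.randn(B)  # Q(s, a_data) for dataset actions
    q_rand = torch.randn(B, n_rand)  # Q(s, a_rand) for uniformly sampled actions
    q_pi = torch.randn(B)  # Q(s, a_pi) for current-policy actions
    returns = torch.randn(B)  # Monte-Carlo return-to-go from the dataset

    # Cal-QL calibration: lower-bound the policy Q values by the MC returns
    q_pi = torch.max(q_pi, returns)[:, None]  # (B, 1)

    # Conservative penalty: soft maximum (logsumexp) over out-of-distribution
    # actions minus the Q value of the dataset action
    cat_q = torch.cat([q_rand, q_pi], dim=-1)  # (B, n_rand + 1)
    ood_q = torch.logsumexp(cat_q, dim=-1)  # (B,)
    penalty = (ood_q - q_data).mean()
    print(f"example conservative penalty: {penalty.item():.3f}")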