dppo/model/rl/gaussian_ibrl.py

"""
Imitation Bootstrapped Reinforcement Learning (IBRL) for Gaussian policy.
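
At action-selection time, both a frozen pre-trained (behavior-cloned) policy
and the online RL policy propose an action; the proposal with the higher Q
value under two randomly chosen critics is executed, either greedily or by
soft sampling with temperature soft_action_sample_beta.

Reference: Hu et al., "Imitation Bootstrapped Reinforcement Learning".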
"""
import torch
import torch.nn as nn
import logging
from copy import deepcopy
from model.common.gaussian import GaussianModel
log = logging.getLogger(__name__)
class IBRL_Gaussian(GaussianModel):
def __init__(
self,
actor,
critic,
n_critics,
soft_action_sample=False,
soft_action_sample_beta=10,
**kwargs,
):
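        """
        actor: Gaussian policy network; a frozen copy serves as the pre-trained
            BC policy and a second copy as the target actor.
        critic: Q(s, a) network template, duplicated n_critics times to form
            the critic ensemble (plus matching target networks).
        soft_action_sample / soft_action_sample_beta: if enabled, choose
            between the BC and RL action proposals by sampling with weights
            proportional to exp(beta * Q) rather than always taking the
            higher-Q proposal.
        """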
super().__init__(network=actor, **kwargs)
self.soft_action_sample = soft_action_sample
self.soft_action_sample_beta = soft_action_sample_beta
# Set up target actor
self.target_actor = deepcopy(actor)
# Frozen pre-trained policy
self.bc_policy = deepcopy(actor)
for param in self.bc_policy.parameters():
param.requires_grad = False
# initialize critic networks
self.critic_networks = [
deepcopy(critic).to(self.device) for _ in range(n_critics)
]
self.critic_networks = nn.ModuleList(self.critic_networks)
# initialize target networks
self.target_networks = [
deepcopy(critic).to(self.device) for _ in range(n_critics)
]
self.target_networks = nn.ModuleList(self.target_networks)
# Construct a "stateless" version of one of the models. It is "stateless" in the sense that the parameters are meta Tensors and do not have storage.
base_model = deepcopy(self.critic_networks[0])
self.base_model = base_model.to("meta")
self.ensemble_params, self.ensemble_buffers = torch.func.stack_module_state(
self.critic_networks
)
def critic_wrapper(self, params, buffers, data):
"""for vmap"""
return torch.func.functional_call(self.base_model, (params, buffers), data)
def get_random_indices(self, sz=None, num_ind=2):
"""get num_ind random indices from a set of size sz (used for getting critic targets)"""
if sz is None:
sz = len(self.critic_networks)
perm = torch.randperm(sz)
ind = perm[:num_ind].to(self.device)
return ind
def loss_critic(
self,
obs,
next_obs,
actions,
rewards,
terminated,
gamma,
):
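        """
        IBRL TD target, following the computation below:
            y = r + gamma * (1 - terminated)
                  * max( min(Q'_1, Q'_2)(s', a_BC), min(Q'_1, Q'_2)(s', a_RL) )
        where a_BC is the deterministic action of the frozen BC policy, a_RL is
        sampled from the target actor, and Q'_1, Q'_2 are two randomly chosen
        target critics. Every critic in the ensemble regresses onto y via MSE.
        """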
# get random critic index
q1_ind, q2_ind = self.get_random_indices()
with torch.no_grad():
next_actions_bc = super().forward(
cond=next_obs,
deterministic=True,
network_override=self.bc_policy,
)
next_actions_rl = super().forward(
cond=next_obs,
deterministic=False,
network_override=self.target_actor,
)
# get the BC Q value
next_q1_bc = self.target_networks[q1_ind](next_obs, next_actions_bc)
next_q2_bc = self.target_networks[q2_ind](next_obs, next_actions_bc)
next_q_bc = torch.min(next_q1_bc, next_q2_bc)
# get the RL Q value
next_q1_rl = self.target_networks[q1_ind](next_obs, next_actions_rl)
next_q2_rl = self.target_networks[q2_ind](next_obs, next_actions_rl)
next_q_rl = torch.min(next_q1_rl, next_q2_rl)
# take the max Q value
next_q = torch.where(next_q_bc > next_q_rl, next_q_bc, next_q_rl)
# target value
target_q = rewards + gamma * (1 - terminated) * next_q # (B,)
# run all critics in batch
current_q = torch.vmap(self.critic_wrapper, in_dims=(0, 0, None))(
self.ensemble_params, self.ensemble_buffers, (obs, actions)
) # (n_critics, B)
loss_critic = torch.mean((current_q - target_q[None]) ** 2)
return loss_critic
def loss_actor(self, obs):
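        """Maximize the minimum ensemble Q value at the policy's own action,
        i.e. minimize -E[ min_i Q_i(s, pi(s)) ]."""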
action = super().forward(
obs,
deterministic=False,
reparameterize=True,
        )  # use the online policy only; note that IBRL does not apply tanh squashing
current_q = torch.vmap(self.critic_wrapper, in_dims=(0, 0, None))(
self.ensemble_params, self.ensemble_buffers, (obs, action)
) # (n_critics, B)
current_q = current_q.min(
dim=0
).values # unlike RLPD, IBRL uses the min Q value for actor update
loss_actor = -torch.mean(current_q)
return loss_actor
def update_target_critic(self, tau):
"""need to use ensemble_params instead of critic_networks"""
for target_ind, target_critic in enumerate(self.target_networks):
for target_param_name, target_param in target_critic.named_parameters():
source_param = self.ensemble_params[target_param_name][target_ind]
target_param.data.copy_(
target_param.data * (1.0 - tau) + source_param.data * tau
)
def update_target_actor(self, tau):
for target_param, source_param in zip(
self.target_actor.parameters(), self.network.parameters()
):
target_param.data.copy_(
target_param.data * (1.0 - tau) + source_param.data * tau
)
# ---------- Sampling ----------#
def forward(
self,
cond,
deterministic=False,
reparameterize=False,
):
"""use both pre-trained and online policies"""
q1_ind, q2_ind = self.get_random_indices()
# sample an action from the BC policy
bc_action = super().forward(
cond=cond,
deterministic=True,
network_override=self.bc_policy,
)
# sample an action from the RL policy
rl_action = super().forward(
cond=cond,
deterministic=deterministic,
reparameterize=reparameterize,
)
# compute Q value of BC policy
q_bc_1 = self.critic_networks[q1_ind](cond, bc_action) # (B,)
q_bc_2 = self.critic_networks[q2_ind](cond, bc_action)
q_bc = torch.min(q_bc_1, q_bc_2)
# compute Q value of RL policy
q_rl_1 = self.critic_networks[q1_ind](cond, rl_action)
q_rl_2 = self.critic_networks[q2_ind](cond, rl_action)
q_rl = torch.min(q_rl_1, q_rl_2)
# soft sample or greedy
if deterministic or not self.soft_action_sample:
action = torch.where(
(q_bc > q_rl)[:, None, None],
bc_action,
rl_action,
)
else:
            # weight each proposal with probability proportional to exp(beta * Q(a));
            # softmax over the beta-scaled Q values yields exactly these weights
            q_weights = torch.softmax(
                torch.stack([q_bc, q_rl], dim=-1) * self.soft_action_sample_beta,
                dim=-1,
            )
# sample according to the weights
q_indices = torch.multinomial(q_weights, 1)
action = torch.where(
(q_indices == 0)[:, None],
bc_action,
rl_action,
)
return action
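

if __name__ == "__main__":
    # Minimal sketch of the BC-vs-RL action selection used in forward(), run
    # on dummy tensors. Shapes and the beta value here are assumptions for
    # illustration only and are not part of the original module.
    B, Ta, Da = 4, 1, 3  # batch size, action horizon, action dim (assumed)
    beta = 10.0
    q_bc, q_rl = torch.randn(B), torch.randn(B)
    bc_action, rl_action = torch.randn(B, Ta, Da), torch.randn(B, Ta, Da)

    # Greedy: take whichever proposal has the higher Q value.
    greedy = torch.where((q_bc > q_rl)[:, None, None], bc_action, rl_action)

    # Soft: pick the BC proposal with probability proportional to exp(beta * Q).
    probs = torch.softmax(torch.stack([q_bc, q_rl], dim=-1) * beta, dim=-1)
    pick_bc = torch.multinomial(probs, 1) == 0  # (B, 1)
    soft = torch.where(pick_bc[:, None], bc_action, rl_action)
    print(greedy.shape, soft.shape)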