""" Imitation Bootstrapped Reinforcement Learning (IBRL) for Gaussian policy. """ import torch import torch.nn as nn import logging from copy import deepcopy from model.common.gaussian import GaussianModel log = logging.getLogger(__name__) class IBRL_Gaussian(GaussianModel): def __init__( self, actor, critic, n_critics, soft_action_sample=False, soft_action_sample_beta=10, **kwargs, ): super().__init__(network=actor, **kwargs) self.soft_action_sample = soft_action_sample self.soft_action_sample_beta = soft_action_sample_beta # Set up target actor self.target_actor = deepcopy(actor) # Frozen pre-trained policy self.bc_policy = deepcopy(actor) for param in self.bc_policy.parameters(): param.requires_grad = False # initialize critic networks self.critic_networks = [ deepcopy(critic).to(self.device) for _ in range(n_critics) ] self.critic_networks = nn.ModuleList(self.critic_networks) # initialize target networks self.target_networks = [ deepcopy(critic).to(self.device) for _ in range(n_critics) ] self.target_networks = nn.ModuleList(self.target_networks) # Construct a "stateless" version of one of the models. It is "stateless" in the sense that the parameters are meta Tensors and do not have storage. base_model = deepcopy(self.critic_networks[0]) self.base_model = base_model.to("meta") self.ensemble_params, self.ensemble_buffers = torch.func.stack_module_state( self.critic_networks ) def critic_wrapper(self, params, buffers, data): """for vmap""" return torch.func.functional_call(self.base_model, (params, buffers), data) def get_random_indices(self, sz=None, num_ind=2): """get num_ind random indices from a set of size sz (used for getting critic targets)""" if sz is None: sz = len(self.critic_networks) perm = torch.randperm(sz) ind = perm[:num_ind].to(self.device) return ind def loss_critic( self, obs, next_obs, actions, rewards, terminated, gamma, ): # get random critic index q1_ind, q2_ind = self.get_random_indices() with torch.no_grad(): next_actions_bc = super().forward( cond=next_obs, deterministic=True, network_override=self.bc_policy, ) next_actions_rl = super().forward( cond=next_obs, deterministic=False, network_override=self.target_actor, ) # get the BC Q value next_q1_bc = self.target_networks[q1_ind](next_obs, next_actions_bc) next_q2_bc = self.target_networks[q2_ind](next_obs, next_actions_bc) next_q_bc = torch.min(next_q1_bc, next_q2_bc) # get the RL Q value next_q1_rl = self.target_networks[q1_ind](next_obs, next_actions_rl) next_q2_rl = self.target_networks[q2_ind](next_obs, next_actions_rl) next_q_rl = torch.min(next_q1_rl, next_q2_rl) # take the max Q value next_q = torch.where(next_q_bc > next_q_rl, next_q_bc, next_q_rl) # target value target_q = rewards + gamma * (1 - terminated) * next_q # (B,) # run all critics in batch current_q = torch.vmap(self.critic_wrapper, in_dims=(0, 0, None))( self.ensemble_params, self.ensemble_buffers, (obs, actions) ) # (n_critics, B) loss_critic = torch.mean((current_q - target_q[None]) ** 2) return loss_critic def loss_actor(self, obs): action = super().forward( obs, deterministic=False, reparameterize=True, ) # use online policy only, also IBRL does not use tanh squashing current_q = torch.vmap(self.critic_wrapper, in_dims=(0, 0, None))( self.ensemble_params, self.ensemble_buffers, (obs, action) ) # (n_critics, B) current_q = current_q.min( dim=0 ).values # unlike RLPD, IBRL uses the min Q value for actor update loss_actor = -torch.mean(current_q) return loss_actor def update_target_critic(self, tau): """need to use ensemble_params instead 
of critic_networks""" for target_ind, target_critic in enumerate(self.target_networks): for target_param_name, target_param in target_critic.named_parameters(): source_param = self.ensemble_params[target_param_name][target_ind] target_param.data.copy_( target_param.data * (1.0 - tau) + source_param.data * tau ) def update_target_actor(self, tau): for target_param, source_param in zip( self.target_actor.parameters(), self.network.parameters() ): target_param.data.copy_( target_param.data * (1.0 - tau) + source_param.data * tau ) # ---------- Sampling ----------# def forward( self, cond, deterministic=False, reparameterize=False, ): """use both pre-trained and online policies""" q1_ind, q2_ind = self.get_random_indices() # sample an action from the BC policy bc_action = super().forward( cond=cond, deterministic=True, network_override=self.bc_policy, ) # sample an action from the RL policy rl_action = super().forward( cond=cond, deterministic=deterministic, reparameterize=reparameterize, ) # compute Q value of BC policy q_bc_1 = self.critic_networks[q1_ind](cond, bc_action) # (B,) q_bc_2 = self.critic_networks[q2_ind](cond, bc_action) q_bc = torch.min(q_bc_1, q_bc_2) # compute Q value of RL policy q_rl_1 = self.critic_networks[q1_ind](cond, rl_action) q_rl_2 = self.critic_networks[q2_ind](cond, rl_action) q_rl = torch.min(q_rl_1, q_rl_2) # soft sample or greedy if deterministic or not self.soft_action_sample: action = torch.where( (q_bc > q_rl)[:, None, None], bc_action, rl_action, ) else: # compute the Q weights with probability proportional to exp(\beta * Q(a)) qw_bc = torch.exp(q_bc * self.soft_action_sample_beta) qw_rl = torch.exp(q_rl * self.soft_action_sample_beta) q_weights = torch.softmax( torch.stack([qw_bc, qw_rl], dim=-1), dim=-1, ) # sample according to the weights q_indices = torch.multinomial(q_weights, 1) action = torch.where( (q_indices == 0)[:, None], bc_action, rl_action, ) return action
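

# ---------- Example usage ----------#
# Illustrative sketch only (not part of the original module). It assumes
# hypothetical `actor` and `critic` modules matching the call signatures used
# above: the actor is consumed by GaussianModel, and critic(obs, action)
# returns per-sample Q values of shape (B,). Hyperparameter values are
# placeholders.
#
# model = IBRL_Gaussian(
#     actor=actor,
#     critic=critic,
#     n_critics=5,
#     soft_action_sample=True,
#     soft_action_sample_beta=10,
# )
# loss_c = model.loss_critic(obs, next_obs, actions, rewards, terminated, gamma=0.99)
# loss_a = model.loss_actor(obs)
# model.update_target_critic(tau=0.01)
# model.update_target_actor(tau=0.01)
# action = model(cond=obs, deterministic=False)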