""" Gaussian policy parameterization. """ import torch import torch.distributions as D import logging log = logging.getLogger(__name__) class GaussianModel(torch.nn.Module): def __init__( self, network, horizon_steps, network_path=None, device="cuda:0", randn_clip_value=10, ): super().__init__() self.device = device self.network = network.to(device) if network_path is not None: checkpoint = torch.load( network_path, map_location=self.device, weights_only=True ) self.load_state_dict( checkpoint["model"], strict=False, ) log.info("Loaded actor from %s", network_path) log.info( f"Number of network parameters: {sum(p.numel() for p in self.parameters())}" ) self.horizon_steps = horizon_steps # Clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean self.randn_clip_value = randn_clip_value def loss( self, true_action, cond, ent_coef, ): B = len(true_action) dist = self.forward_train( cond, deterministic=False, ) true_action = true_action.view(B, -1) loss = -dist.log_prob(true_action) # [B] entropy = dist.entropy().mean() loss = loss.mean() - entropy * ent_coef return loss, {"entropy": entropy} def forward_train( self, cond, deterministic=False, network_override=None, ): """ Calls the MLP to compute the mean, scale, and logits of the GMM. Returns the torch.Distribution object. """ if network_override is not None: means, scales = network_override(cond) else: means, scales = self.network(cond) if deterministic: # low-noise for all Gaussian dists scales = torch.ones_like(means) * 1e-4 return D.Normal(loc=means, scale=scales) def forward( self, cond, deterministic=False, network_override=None, ): B = len(cond["state"]) if "state" in cond else len(cond["rgb"]) T = self.horizon_steps dist = self.forward_train( cond, deterministic=deterministic, network_override=network_override, ) sampled_action = dist.sample() sampled_action.clamp_( dist.loc - self.randn_clip_value * dist.scale, dist.loc + self.randn_clip_value * dist.scale, ) return sampled_action.view(B, T, -1)