dppo/model/rl/gmm_vpg.py
2024-09-11 21:09:17 -04:00

47 lines
1.0 KiB
Python

import torch
import logging
from model.common.gmm import GMMModel
class VPG_GMM(GMMModel):
def __init__(
self,
actor,
critic,
**kwargs,
):
super().__init__(network=actor, **kwargs)
# Re-name network to actor
self.actor_ft = actor
# Value function for obs - simple MLP
self.critic = critic.to(self.device)
# ---------- Sampling ----------#
@torch.no_grad()
def forward(self, cond, deterministic=False):
return super().forward(
cond=cond,
deterministic=deterministic,
)
# ---------- RL training ----------#
def get_logprobs(
self,
cond,
actions,
):
B = len(actions)
dist, entropy, std = self.forward_train(
cond,
deterministic=False,
)
log_prob = dist.log_prob(actions.view(B, -1))
return log_prob, entropy, std
def loss(self, obs, chains, reward):
raise NotImplementedError