47 lines
1.0 KiB
Python
47 lines
1.0 KiB
Python
import torch
|
|
import logging
|
|
from model.common.gmm import GMMModel
|
|
|
|
|
|
class VPG_GMM(GMMModel):
    """Vanilla Policy Gradient agent with a Gaussian-mixture policy head.

    Wraps a GMM actor network (managed by the ``GMMModel`` base class) and
    pairs it with a separate critic network used as a state-value baseline.
    """

    def __init__(self, actor, critic, **kwargs):
        """Register the actor with the base class and attach the critic.

        Args:
            actor: GMM policy network; passed to ``GMMModel`` as ``network``.
            critic: value-function network over observations (simple MLP).
            **kwargs: extra configuration forwarded to ``GMMModel.__init__``.
        """
        super().__init__(network=actor, **kwargs)
        # Alias the base-class network under a fine-tuning-specific name.
        self.actor_ft = actor
        # Baseline value function, kept on the same device as the actor.
        self.critic = critic.to(self.device)

    # ---------- Sampling ----------#

    @torch.no_grad()
    def forward(self, cond, deterministic=False):
        """Sample actions from the GMM policy without tracking gradients."""
        return super().forward(cond=cond, deterministic=deterministic)

    # ---------- RL training ----------#

    def get_logprobs(self, cond, actions):
        """Evaluate log-probabilities of ``actions`` under the current policy.

        Args:
            cond: conditioning input for the policy network.
            actions: action batch; flattened per sample before scoring.

        Returns:
            Tuple ``(log_prob, entropy, std)`` — log-probabilities of the
            flattened actions plus the entropy and std from ``forward_train``.
        """
        batch_size = len(actions)
        dist, entropy, std = self.forward_train(cond, deterministic=False)
        # Flatten each sample so it matches the distribution's event shape.
        flat_actions = actions.view(batch_size, -1)
        return dist.log_prob(flat_actions), entropy, std

    def loss(self, obs, chains, reward):
        """Policy-gradient loss; concrete subclasses must implement this."""
        raise NotImplementedError
|