add separate eval model class that also initializes the pre-trained policy for early denoising steps
This commit is contained in:
parent
e7f73dffc1
commit
ace2bbdab9
@ -138,7 +138,7 @@ See [here](cfg/finetuning.md) for details of the experiments in the paper.
|
||||
* Videos of trials in Robomimic tasks can be recorded by specifying `env.save_video=True`, `train.render.freq=<iterations>`, and `train.render.num=<num_video>` in fine-tuning configs.
|
||||
|
||||
## Usage - Evaluation
|
||||
Pre-trained or fine-tuned policies can be evaluated without running the fine-tuning script now. Some example configs are provided under `cfg/{gym/robomimic/furniture}/eval}` including ones below. Set `base_policy_path` to override the default checkpoint.
|
||||
Pre-trained or fine-tuned policies can be evaluated without running the fine-tuning script now. Some example configs are provided under `cfg/{gym/robomimic/furniture}/eval}` including ones below. `ft_denoising_steps` needs to match fine-tuning config. Set `base_policy_path` to override the default checkpoint.
|
||||
```console
|
||||
python script/run.py --config-name=eval_diffusion_mlp \
|
||||
--config-dir=cfg/gym/eval/hopper-v2
|
||||
|
@ -128,7 +128,7 @@ class TrainAgent:
|
||||
data = {
|
||||
"itr": self.itr,
|
||||
"model": self.model.state_dict(),
|
||||
}
|
||||
} # right now `model` includes weights for `network`, `actor`, `actor_ft`. Weights for `network` is redundant, and we can use `actor` weights as the base policy (earlier denoising steps) and `actor_ft` weights as the fine-tuned policy (later denoising steps) during evaluation.
|
||||
savepath = os.path.join(self.checkpoint_dir, f"state_{self.itr}.pt")
|
||||
torch.save(data, savepath)
|
||||
log.info(f"Saved model to {savepath}")
|
||||
|
@ -289,6 +289,7 @@ class DiffusionModel(nn.Module):
|
||||
t=t_b,
|
||||
cond=cond,
|
||||
index=index_b,
|
||||
deterministic=deterministic,
|
||||
)
|
||||
std = torch.exp(0.5 * logvar)
|
||||
|
||||
|
135
model/diffusion/diffusion_eval_ft.py
Normal file
135
model/diffusion/diffusion_eval_ft.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""
|
||||
For evaluating RL fine-tuned diffusion policy
|
||||
|
||||
Account for frozen base policy for early denoising steps and fine-tuned policy for later denoising steps
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from model.diffusion.diffusion import DiffusionModel
|
||||
from model.diffusion.sampling import extract
|
||||
|
||||
|
||||
class DiffusionEvalFT(DiffusionModel):
|
||||
def __init__(
|
||||
self,
|
||||
use_ddim,
|
||||
ft_denoising_steps,
|
||||
network_path,
|
||||
**kwargs,
|
||||
):
|
||||
# do not let base class load model
|
||||
super().__init__(use_ddim=use_ddim, network_path=None, **kwargs)
|
||||
self.ft_denoising_steps = ft_denoising_steps
|
||||
checkpoint = torch.load(
|
||||
network_path, map_location=self.device, weights_only=True
|
||||
) # 'network.mlp_mean...', 'actor.mlp_mean...', 'actor_ft.mlp_mean...'
|
||||
|
||||
# Set up base model --- techncally not needed if all denoising steps are fine-tuned
|
||||
self.actor = self.network
|
||||
base_weights = {
|
||||
key.split("actor.")[1]: checkpoint["model"][key]
|
||||
for key in checkpoint["model"]
|
||||
if "actor." in key
|
||||
}
|
||||
self.actor.load_state_dict(base_weights, strict=True)
|
||||
logging.info("Loaded base policy weights from %s", network_path)
|
||||
|
||||
# Always set up fine-tuned model
|
||||
self.actor_ft = copy.deepcopy(self.network)
|
||||
ft_weights = {
|
||||
key.split("actor_ft.")[1]: checkpoint["model"][key]
|
||||
for key in checkpoint["model"]
|
||||
if "actor_ft." in key
|
||||
}
|
||||
self.actor_ft.load_state_dict(ft_weights, strict=True)
|
||||
logging.info("Loaded fine-tuned policy weights from %s", network_path)
|
||||
|
||||
# override
|
||||
def p_mean_var(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
cond,
|
||||
index=None,
|
||||
deterministic=False,
|
||||
):
|
||||
noise = self.actor(x, t, cond=cond)
|
||||
if self.use_ddim:
|
||||
ft_indices = torch.where(
|
||||
index >= (self.ddim_steps - self.ft_denoising_steps)
|
||||
)[0]
|
||||
else:
|
||||
ft_indices = torch.where(t < self.ft_denoising_steps)[0]
|
||||
|
||||
# overwrite noise for fine-tuning steps
|
||||
if len(ft_indices) > 0:
|
||||
cond_ft = {key: cond[key][ft_indices] for key in cond}
|
||||
noise_ft = self.actor_ft(x[ft_indices], t[ft_indices], cond=cond_ft)
|
||||
noise[ft_indices] = noise_ft
|
||||
|
||||
# Predict x_0
|
||||
if self.predict_epsilon:
|
||||
if self.use_ddim:
|
||||
"""
|
||||
x₀ = (xₜ - √ (1-αₜ) ε )/ √ αₜ
|
||||
"""
|
||||
alpha = extract(self.ddim_alphas, index, x.shape)
|
||||
alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
|
||||
sqrt_one_minus_alpha = extract(
|
||||
self.ddim_sqrt_one_minus_alphas, index, x.shape
|
||||
)
|
||||
x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha**0.5)
|
||||
else:
|
||||
"""
|
||||
x₀ = √ 1\α̅ₜ xₜ - √ 1\α̅ₜ-1 ε
|
||||
"""
|
||||
x_recon = (
|
||||
extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
||||
- extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
|
||||
)
|
||||
else: # directly predicting x₀
|
||||
x_recon = noise
|
||||
if self.denoised_clip_value is not None:
|
||||
x_recon.clamp_(-self.denoised_clip_value, self.denoised_clip_value)
|
||||
if self.use_ddim:
|
||||
# re-calculate noise based on clamped x_recon - default to false in HF, but let's use it here
|
||||
noise = (x - alpha ** (0.5) * x_recon) / sqrt_one_minus_alpha
|
||||
|
||||
# Clip epsilon for numerical stability in policy gradient - not sure if this is helpful yet, but the value can be huge sometimes. This has no effect if DDPM is used
|
||||
if self.use_ddim and self.eps_clip_value is not None:
|
||||
noise.clamp_(-self.eps_clip_value, self.eps_clip_value)
|
||||
|
||||
# Get mu
|
||||
if self.use_ddim:
|
||||
"""
|
||||
μ = √ αₜ₋₁ x₀ + √(1-αₜ₋₁ - σₜ²) ε
|
||||
"""
|
||||
if deterministic:
|
||||
etas = torch.zeros((x.shape[0], 1, 1)).to(x.device)
|
||||
else:
|
||||
etas = self.eta(cond).unsqueeze(1) # B x 1 x (Da or 1)
|
||||
sigma = (
|
||||
etas
|
||||
* ((1 - alpha_prev) / (1 - alpha) * (1 - alpha / alpha_prev)) ** 0.5
|
||||
).clamp_(min=1e-10)
|
||||
dir_xt_coef = (1.0 - alpha_prev - sigma**2).clamp_(min=0).sqrt()
|
||||
mu = (alpha_prev**0.5) * x_recon + dir_xt_coef * noise
|
||||
var = sigma**2
|
||||
logvar = torch.log(var)
|
||||
else:
|
||||
"""
|
||||
μₜ = β̃ₜ √ α̅ₜ₋₁/(1-α̅ₜ)x₀ + √ αₜ (1-α̅ₜ₋₁)/(1-α̅ₜ)xₜ
|
||||
"""
|
||||
mu = (
|
||||
extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
|
||||
+ extract(self.ddpm_mu_coef2, t, x.shape) * x
|
||||
)
|
||||
logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
|
||||
return mu, logvar
|
Loading…
Reference in New Issue
Block a user