allow loading pre-trained weights (not fine-tuned) in DiffusionEvalFT

This commit is contained in:
allenzren 2025-02-04 11:39:56 -05:00 committed by Allen Z. Ren
parent 169a16dda7
commit a746220905

View File

@ -19,9 +19,9 @@ from model.diffusion.sampling import extract
class DiffusionEvalFT(DiffusionModel):
def __init__(
self,
use_ddim,
ft_denoising_steps,
network_path,
use_ddim=False,
ft_denoising_steps=0, # if running fine-tuned model, need to specify the correct number of denoising steps fine-tuned, so that here it knows which model (base or ft) to use for each denoising step
**kwargs,
):
# do not let base class load model
@ -33,23 +33,38 @@ class DiffusionEvalFT(DiffusionModel):
# Set up base model --- technically not needed if all denoising steps are fine-tuned
self.actor = self.network
base_weights = {
key.split("actor.")[1]: checkpoint["model"][key]
for key in checkpoint["model"]
if "actor." in key
}
self.actor.load_state_dict(base_weights, strict=True)
try:
base_weights = {
key.split("actor.")[1]: checkpoint["model"][key]
for key in checkpoint["model"]
if "actor." in key
}
use_ft = True
self.actor.load_state_dict(base_weights, strict=True)
except Exception:
assert ft_denoising_steps == 0, (
"If no base policy weights are found, ft_denoising_steps must be 0"
)
base_weights = {
key.split("network.")[1]: checkpoint["model"][key]
for key in checkpoint["model"]
if "network." in key
}
use_ft = False
logging.info("Actor weights not found. Using pre-trained weights!")
self.actor.load_state_dict(base_weights, strict=True)
logging.info("Loaded base policy weights from %s", network_path)
# Always set up fine-tuned model
self.actor_ft = copy.deepcopy(self.network)
ft_weights = {
key.split("actor_ft.")[1]: checkpoint["model"][key]
for key in checkpoint["model"]
if "actor_ft." in key
}
self.actor_ft.load_state_dict(ft_weights, strict=True)
logging.info("Loaded fine-tuned policy weights from %s", network_path)
if use_ft:
self.actor_ft = copy.deepcopy(self.network)
ft_weights = {
key.split("actor_ft.")[1]: checkpoint["model"][key]
for key in checkpoint["model"]
if "actor_ft." in key
}
self.actor_ft.load_state_dict(ft_weights, strict=True)
logging.info("Loaded fine-tuned policy weights from %s", network_path)
# override
def p_mean_var(