From a746220905f6d18df0a7868bdd78070c37637c50 Mon Sep 17 00:00:00 2001
From: allenzren
Date: Tue, 4 Feb 2025 11:39:56 -0500
Subject: [PATCH] allow loading pre-trained weights (not fine-tuned) in `DiffusionEvalFT`

---
 model/diffusion/diffusion_eval_ft.py | 47 ++++++++++++++++++----------
 1 file changed, 31 insertions(+), 16 deletions(-)

diff --git a/model/diffusion/diffusion_eval_ft.py b/model/diffusion/diffusion_eval_ft.py
index 99a7938..c5d4c69 100644
--- a/model/diffusion/diffusion_eval_ft.py
+++ b/model/diffusion/diffusion_eval_ft.py
@@ -19,9 +19,9 @@ from model.diffusion.sampling import extract
 class DiffusionEvalFT(DiffusionModel):
     def __init__(
         self,
-        use_ddim,
-        ft_denoising_steps,
         network_path,
+        use_ddim=False,
+        ft_denoising_steps=0,  # if running fine-tuned model, need to specify the correct number of denoising steps fine-tuned, so that here it knows which model (base or ft) to use for each denoising step
         **kwargs,
     ):
         # do not let base class load model
@@ -33,23 +33,38 @@ class DiffusionEvalFT(DiffusionModel):
 
         # Set up base model --- techncally not needed if all denoising steps are fine-tuned
         self.actor = self.network
-        base_weights = {
-            key.split("actor.")[1]: checkpoint["model"][key]
-            for key in checkpoint["model"]
-            if "actor." in key
-        }
-        self.actor.load_state_dict(base_weights, strict=True)
+        try:
+            base_weights = {
+                key.split("actor.")[1]: checkpoint["model"][key]
+                for key in checkpoint["model"]
+                if "actor." in key
+            }
+            use_ft = True
+            self.actor.load_state_dict(base_weights, strict=True)
+        except Exception:
+            assert ft_denoising_steps == 0, (
+                "If no base policy weights are found, ft_denoising_steps must be 0"
+            )
+            base_weights = {
+                key.split("network.")[1]: checkpoint["model"][key]
+                for key in checkpoint["model"]
+                if "network." in key
+            }
+            use_ft = False
+            logging.info("Actor weights not found. Using pre-trained weights!")
+            self.actor.load_state_dict(base_weights, strict=True)
         logging.info("Loaded base policy weights from %s", network_path)
 
         # Always set up fine-tuned model
-        self.actor_ft = copy.deepcopy(self.network)
-        ft_weights = {
-            key.split("actor_ft.")[1]: checkpoint["model"][key]
-            for key in checkpoint["model"]
-            if "actor_ft." in key
-        }
-        self.actor_ft.load_state_dict(ft_weights, strict=True)
-        logging.info("Loaded fine-tuned policy weights from %s", network_path)
+        if use_ft:
+            self.actor_ft = copy.deepcopy(self.network)
+            ft_weights = {
+                key.split("actor_ft.")[1]: checkpoint["model"][key]
+                for key in checkpoint["model"]
+                if "actor_ft." in key
+            }
+            self.actor_ft.load_state_dict(ft_weights, strict=True)
+            logging.info("Loaded fine-tuned policy weights from %s", network_path)
 
     # override
     def p_mean_var(