allow loading pre-trained weights (not fine-tuned) in DiffusionEvalFT
This commit is contained in:
parent
169a16dda7
commit
a746220905
@ -19,9 +19,9 @@ from model.diffusion.sampling import extract
|
||||
class DiffusionEvalFT(DiffusionModel):
|
||||
def __init__(
|
||||
self,
|
||||
use_ddim,
|
||||
ft_denoising_steps,
|
||||
network_path,
|
||||
use_ddim=False,
|
||||
ft_denoising_steps=0, # if running fine-tuned model, need to specify the correct number of denoising steps fine-tuned, so that here it knows which model (base or ft) to use for each denoising step
|
||||
**kwargs,
|
||||
):
|
||||
# do not let base class load model
|
||||
@ -33,23 +33,38 @@ class DiffusionEvalFT(DiffusionModel):
|
||||
|
||||
# Set up base model --- techncally not needed if all denoising steps are fine-tuned
|
||||
self.actor = self.network
|
||||
base_weights = {
|
||||
key.split("actor.")[1]: checkpoint["model"][key]
|
||||
for key in checkpoint["model"]
|
||||
if "actor." in key
|
||||
}
|
||||
self.actor.load_state_dict(base_weights, strict=True)
|
||||
try:
|
||||
base_weights = {
|
||||
key.split("actor.")[1]: checkpoint["model"][key]
|
||||
for key in checkpoint["model"]
|
||||
if "actor." in key
|
||||
}
|
||||
use_ft = True
|
||||
self.actor.load_state_dict(base_weights, strict=True)
|
||||
except Exception:
|
||||
assert ft_denoising_steps == 0, (
|
||||
"If no base policy weights are found, ft_denoising_steps must be 0"
|
||||
)
|
||||
base_weights = {
|
||||
key.split("network.")[1]: checkpoint["model"][key]
|
||||
for key in checkpoint["model"]
|
||||
if "network." in key
|
||||
}
|
||||
use_ft = False
|
||||
logging.info("Actor weights not found. Using pre-trained weights!")
|
||||
self.actor.load_state_dict(base_weights, strict=True)
|
||||
logging.info("Loaded base policy weights from %s", network_path)
|
||||
|
||||
# Always set up fine-tuned model
|
||||
self.actor_ft = copy.deepcopy(self.network)
|
||||
ft_weights = {
|
||||
key.split("actor_ft.")[1]: checkpoint["model"][key]
|
||||
for key in checkpoint["model"]
|
||||
if "actor_ft." in key
|
||||
}
|
||||
self.actor_ft.load_state_dict(ft_weights, strict=True)
|
||||
logging.info("Loaded fine-tuned policy weights from %s", network_path)
|
||||
if use_ft:
|
||||
self.actor_ft = copy.deepcopy(self.network)
|
||||
ft_weights = {
|
||||
key.split("actor_ft.")[1]: checkpoint["model"][key]
|
||||
for key in checkpoint["model"]
|
||||
if "actor_ft." in key
|
||||
}
|
||||
self.actor_ft.load_state_dict(ft_weights, strict=True)
|
||||
logging.info("Loaded fine-tuned policy weights from %s", network_path)
|
||||
|
||||
# override
|
||||
def p_mean_var(
|
||||
|
Loading…
Reference in New Issue
Block a user