allow loading pre-trained weights (not fine-tuned) in DiffusionEvalFT
This commit is contained in:
parent
169a16dda7
commit
a746220905
@ -19,9 +19,9 @@ from model.diffusion.sampling import extract
|
|||||||
class DiffusionEvalFT(DiffusionModel):
|
class DiffusionEvalFT(DiffusionModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
use_ddim,
|
|
||||||
ft_denoising_steps,
|
|
||||||
network_path,
|
network_path,
|
||||||
|
use_ddim=False,
|
||||||
|
ft_denoising_steps=0, # if running fine-tuned model, need to specify the correct number of denoising steps fine-tuned, so that here it knows which model (base or ft) to use for each denoising step
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# do not let base class load model
|
# do not let base class load model
|
||||||
@ -33,15 +33,30 @@ class DiffusionEvalFT(DiffusionModel):
|
|||||||
|
|
||||||
# Set up base model --- techncally not needed if all denoising steps are fine-tuned
|
# Set up base model --- techncally not needed if all denoising steps are fine-tuned
|
||||||
self.actor = self.network
|
self.actor = self.network
|
||||||
|
try:
|
||||||
base_weights = {
|
base_weights = {
|
||||||
key.split("actor.")[1]: checkpoint["model"][key]
|
key.split("actor.")[1]: checkpoint["model"][key]
|
||||||
for key in checkpoint["model"]
|
for key in checkpoint["model"]
|
||||||
if "actor." in key
|
if "actor." in key
|
||||||
}
|
}
|
||||||
|
use_ft = True
|
||||||
|
self.actor.load_state_dict(base_weights, strict=True)
|
||||||
|
except Exception:
|
||||||
|
assert ft_denoising_steps == 0, (
|
||||||
|
"If no base policy weights are found, ft_denoising_steps must be 0"
|
||||||
|
)
|
||||||
|
base_weights = {
|
||||||
|
key.split("network.")[1]: checkpoint["model"][key]
|
||||||
|
for key in checkpoint["model"]
|
||||||
|
if "network." in key
|
||||||
|
}
|
||||||
|
use_ft = False
|
||||||
|
logging.info("Actor weights not found. Using pre-trained weights!")
|
||||||
self.actor.load_state_dict(base_weights, strict=True)
|
self.actor.load_state_dict(base_weights, strict=True)
|
||||||
logging.info("Loaded base policy weights from %s", network_path)
|
logging.info("Loaded base policy weights from %s", network_path)
|
||||||
|
|
||||||
# Always set up fine-tuned model
|
# Always set up fine-tuned model
|
||||||
|
if use_ft:
|
||||||
self.actor_ft = copy.deepcopy(self.network)
|
self.actor_ft = copy.deepcopy(self.network)
|
||||||
ft_weights = {
|
ft_weights = {
|
||||||
key.split("actor_ft.")[1]: checkpoint["model"][key]
|
key.split("actor_ft.")[1]: checkpoint["model"][key]
|
||||||
|
Loading…
Reference in New Issue
Block a user