More frequent EMA update (#20)
* move ema update within pretraining epoch * update pretraining ema configs * add lift and can epoch 8000 checkpoint url * add note about EMA issue in pretraining instruction
This commit is contained in:
parent
dc8e0c9edc
commit
e1ef4ca1cf
@ -36,6 +36,10 @@ class TrainDiffusionAgent(PreTrainAgent):
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# update ema
|
||||
if self.epoch % self.update_ema_freq == 0:
|
||||
self.step_ema()
|
||||
loss_train = np.mean(loss_train_epoch)
|
||||
|
||||
# validate
|
||||
@ -53,10 +57,6 @@ class TrainDiffusionAgent(PreTrainAgent):
|
||||
# update lr
|
||||
self.lr_scheduler.step()
|
||||
|
||||
# update ema
|
||||
if self.epoch % self.update_ema_freq == 0:
|
||||
self.step_ema()
|
||||
|
||||
# save model
|
||||
if self.epoch % self.save_model_freq == 0 or self.epoch == self.n_epochs:
|
||||
self.save_model()
|
||||
|
@ -44,6 +44,10 @@ class TrainGaussianAgent(PreTrainAgent):
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# update ema
|
||||
if self.epoch % self.update_ema_freq == 0:
|
||||
self.step_ema()
|
||||
loss_train = np.mean(loss_train_epoch)
|
||||
ent_train = np.mean(ent_train_epoch)
|
||||
|
||||
@ -65,10 +69,6 @@ class TrainGaussianAgent(PreTrainAgent):
|
||||
# update lr
|
||||
self.lr_scheduler.step()
|
||||
|
||||
# update ema
|
||||
if self.epoch % self.update_ema_freq == 0:
|
||||
self.step_ema()
|
||||
|
||||
# save model
|
||||
if self.epoch % self.save_model_freq == 0 or self.epoch == self.n_epochs:
|
||||
self.save_model()
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gmm.GMMModel
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gmm.GMMModel
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gmm.GMMModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -35,8 +35,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -34,8 +34,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
first_cycle_steps: 3000
|
||||
warmup_steps: 1
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 10
|
||||
update_ema_freq: 5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 100
|
||||
|
||||
model:
|
||||
|
@ -31,8 +31,8 @@ train:
|
||||
first_cycle_steps: 1000
|
||||
warmup_steps: 1
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 10
|
||||
update_ema_freq: 5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 100
|
||||
|
||||
model:
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
first_cycle_steps: 3000
|
||||
warmup_steps: 1
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 10
|
||||
update_ema_freq: 5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 100
|
||||
|
||||
model:
|
||||
|
@ -31,8 +31,8 @@ train:
|
||||
first_cycle_steps: 1000
|
||||
warmup_steps: 1
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 10
|
||||
update_ema_freq: 5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 100
|
||||
|
||||
model:
|
||||
|
@ -32,9 +32,9 @@ train:
|
||||
first_cycle_steps: 8000
|
||||
warmup_steps: 1
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 10
|
||||
update_ema_freq: 5
|
||||
save_model_freq: 1000
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -32,9 +32,9 @@ train:
|
||||
first_cycle_steps: 8000
|
||||
warmup_steps: 1
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 10
|
||||
update_ema_freq: 5
|
||||
save_model_freq: 1000
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -31,9 +31,9 @@ train:
|
||||
first_cycle_steps: 5000
|
||||
warmup_steps: 1
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 10
|
||||
update_ema_freq: 5
|
||||
save_model_freq: 1000
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -32,9 +32,9 @@ train:
|
||||
first_cycle_steps: 8000
|
||||
warmup_steps: 1
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 10
|
||||
update_ema_freq: 5
|
||||
save_model_freq: 1000
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -31,9 +31,9 @@ train:
|
||||
first_cycle_steps: 5000
|
||||
warmup_steps: 1
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 10
|
||||
update_ema_freq: 5
|
||||
save_model_freq: 1000
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
first_cycle_steps: 3000
|
||||
warmup_steps: 1
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 10
|
||||
update_ema_freq: 5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 100
|
||||
|
||||
model:
|
||||
|
@ -31,8 +31,8 @@ train:
|
||||
first_cycle_steps: 3000
|
||||
warmup_steps: 1
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 10
|
||||
update_ema_freq: 5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 100
|
||||
|
||||
model:
|
||||
|
@ -1,5 +1,7 @@
|
||||
## Pre-training experiments
|
||||
|
||||
**Update, Nov 6 2024**: we fixed the issue of EMA update being too infrequent causing slow pre-training. Now the number of epochs needed for pre-training can be much slower than those used in the configs. We recommend training with fewer epochs and testing the early checkpoints.
|
||||
|
||||
### Comparing diffusion-based RL algorithms (Sec. 5.1)
|
||||
Gym configs are under `cfg/gym/pretrain/<env_name>/`, and the config name is `pre_diffusion_mlp`. Robomimic configs are under `cfg/robomimic/pretrain/<env_name>/`, and the name is also `pre_diffusion_mlp`.
|
||||
|
||||
|
@ -62,7 +62,7 @@ train:
|
||||
first_cycle_steps: 1000
|
||||
warmup_steps: 10
|
||||
min_lr: 1e-4
|
||||
save_model_freq: 100000
|
||||
save_model_freq: 50000
|
||||
val_freq: 10000
|
||||
render:
|
||||
freq: 10000
|
||||
|
@ -62,7 +62,7 @@ train:
|
||||
first_cycle_steps: 1000
|
||||
warmup_steps: 10
|
||||
min_lr: 1e-4
|
||||
save_model_freq: 100000
|
||||
save_model_freq: 50000
|
||||
val_freq: 10000
|
||||
render:
|
||||
freq: 10000
|
||||
|
@ -62,7 +62,7 @@ train:
|
||||
first_cycle_steps: 1000
|
||||
warmup_steps: 10
|
||||
min_lr: 1e-4
|
||||
save_model_freq: 100000
|
||||
save_model_freq: 50000
|
||||
val_freq: 10000
|
||||
render:
|
||||
freq: 10000
|
||||
|
@ -62,7 +62,7 @@ train:
|
||||
first_cycle_steps: 1000
|
||||
warmup_steps: 10
|
||||
min_lr: 1e-4
|
||||
save_model_freq: 100000
|
||||
save_model_freq: 50000
|
||||
val_freq: 10000
|
||||
render:
|
||||
freq: 10000
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -43,7 +43,7 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -42,7 +42,7 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gmm.GMMModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gmm.GMMModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -43,7 +43,7 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -42,7 +42,7 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gmm.GMMModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gmm.GMMModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -43,7 +43,7 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-4
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -42,7 +42,7 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gmm.GMMModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gmm.GMMModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -43,7 +43,7 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.diffusion.diffusion.DiffusionModel
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -42,7 +42,7 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
|
@ -32,8 +32,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gaussian.GaussianModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gmm.GMMModel
|
||||
|
@ -33,8 +33,8 @@ train:
|
||||
warmup_steps: 100
|
||||
min_lr: 1e-5
|
||||
epoch_start_ema: 20
|
||||
update_ema_freq: 10
|
||||
save_model_freq: 1000
|
||||
update_ema_freq: 1
|
||||
save_model_freq: 500
|
||||
|
||||
model:
|
||||
_target_: model.common.gmm.GMMModel
|
||||
|
@ -61,7 +61,7 @@ train:
|
||||
first_cycle_steps: 1000
|
||||
warmup_steps: 10
|
||||
min_lr: 1e-4
|
||||
save_model_freq: 100000
|
||||
save_model_freq: 50000
|
||||
val_freq: 10000
|
||||
render:
|
||||
freq: 10000
|
||||
|
@ -61,7 +61,7 @@ train:
|
||||
first_cycle_steps: 1000
|
||||
warmup_steps: 10
|
||||
min_lr: 1e-4
|
||||
save_model_freq: 100000
|
||||
save_model_freq: 50000
|
||||
val_freq: 10000
|
||||
render:
|
||||
freq: 10000
|
||||
|
@ -61,7 +61,7 @@ train:
|
||||
first_cycle_steps: 1000
|
||||
warmup_steps: 10
|
||||
min_lr: 1e-4
|
||||
save_model_freq: 100000
|
||||
save_model_freq: 50000
|
||||
val_freq: 10000
|
||||
render:
|
||||
freq: 10000
|
||||
|
@ -61,7 +61,7 @@ train:
|
||||
first_cycle_steps: 1000
|
||||
warmup_steps: 10
|
||||
min_lr: 1e-4
|
||||
save_model_freq: 100000
|
||||
save_model_freq: 50000
|
||||
val_freq: 10000
|
||||
render:
|
||||
freq: 10000
|
||||
|
@ -284,6 +284,11 @@ def get_checkpoint_download_url(cfg):
|
||||
in path
|
||||
):
|
||||
return "https://drive.google.com/file/d/1Ngr-DNxoB9XNCZ2O-NF5p60NzmYlzmWG/view?usp=drive_link"
|
||||
elif (
|
||||
"lift_pre_diffusion_mlp_ta4_td20/2024-06-28_14-47-58/checkpoint/state_8000.pt"
|
||||
in path
|
||||
):
|
||||
return "https://drive.google.com/file/d/1IyXa6CEXO16mmCCHgNfFTnmvAhA3PVxQ/view?usp=drive_link"
|
||||
elif (
|
||||
"lift_pre_diffusion_mlp_img_ta4_td100/2024-07-30_22-24-35/checkpoint/state_2500.pt"
|
||||
in path
|
||||
@ -323,6 +328,11 @@ def get_checkpoint_download_url(cfg):
|
||||
in path
|
||||
):
|
||||
return "https://drive.google.com/file/d/1L1ZLD1u1Y1YJmRLGzScXbQ02wGS-_cWo/view?usp=drive_link"
|
||||
elif (
|
||||
"can_pre_diffusion_mlp_ta4_td20/2024-06-28_13-29-54/checkpoint/state_8000.pt"
|
||||
in path
|
||||
):
|
||||
return "https://drive.google.com/file/d/1_3-QcDrWCH6cPRPLuVnYQt25ymvBYHgn/view?usp=drive_link"
|
||||
elif (
|
||||
"can_pre_diffusion_mlp_img_ta4_td100/2024-07-30_22-23-55/checkpoint/state_5000.pt"
|
||||
in path
|
||||
|
Loading…
Reference in New Issue
Block a user