diff --git a/train.py b/train.py index ee18ca0..69e819e 100644 --- a/train.py +++ b/train.py @@ -28,8 +28,8 @@ def train(model, seq_len=16*512): # 1KiB state_h[0], state_c[0] = model.init_state(seq_len) state_h[1], state_c[1] = model.init_state(seq_len) - blob[0], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 0) - blob[1], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 1) + blob[0], _ = shark.getSample(seq_len, 0) + blob[1], _ = shark.getSample(seq_len, 1) optimizer.zero_grad() for i in range(len(blob[0])): for t in range(2):