From ca964732c599e067da0ad46674c3f72cb05db86a Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 22 Sep 2021 11:50:03 +0200 Subject: [PATCH] Removed the slow blob-len growth --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index ee18ca0..69e819e 100644 --- a/train.py +++ b/train.py @@ -28,8 +28,8 @@ def train(model, seq_len=16*512): # 1KiB state_h[0], state_c[0] = model.init_state(seq_len) state_h[1], state_c[1] = model.init_state(seq_len) - blob[0], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 0) - blob[1], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 1) + blob[0], _ = shark.getSample(seq_len, 0) + blob[1], _ = shark.getSample(seq_len, 1) optimizer.zero_grad() for i in range(len(blob[0])): for t in range(2):