Removed the slow blob-len growth

This commit is contained in:
Dominik Moritz Roth 2021-09-22 11:50:03 +02:00
parent 8608a89909
commit ca964732c5

View File

@ -28,8 +28,8 @@ def train(model, seq_len=16*512): # 1KiB
state_h[0], state_c[0] = model.init_state(seq_len)
state_h[1], state_c[1] = model.init_state(seq_len)
blob[0], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 0)
blob[1], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 1)
blob[0], _ = shark.getSample(seq_len, 0)
blob[1], _ = shark.getSample(seq_len, 1)
optimizer.zero_grad()
for i in range(len(blob[0])):
for t in range(2):