Removed the slow blob-len growth

This commit is contained in:
Dominik Moritz Roth 2021-09-22 11:50:03 +02:00
parent 8608a89909
commit ca964732c5

View File

@ -28,8 +28,8 @@ def train(model, seq_len=16*512): # 1KiB
state_h[0], state_c[0] = model.init_state(seq_len) state_h[0], state_c[0] = model.init_state(seq_len)
state_h[1], state_c[1] = model.init_state(seq_len) state_h[1], state_c[1] = model.init_state(seq_len)
blob[0], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 0) blob[0], _ = shark.getSample(seq_len, 0)
blob[1], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 1) blob[1], _ = shark.getSample(seq_len, 1)
optimizer.zero_grad() optimizer.zero_grad()
for i in range(len(blob[0])): for i in range(len(blob[0])):
for t in range(2): for t in range(2):