Removed the slow blob-len growth
This commit is contained in:
parent
8608a89909
commit
ca964732c5
4
train.py
4
train.py
@ -28,8 +28,8 @@ def train(model, seq_len=16*512): # 1KiB
|
|||||||
state_h[0], state_c[0] = model.init_state(seq_len)
|
state_h[0], state_c[0] = model.init_state(seq_len)
|
||||||
state_h[1], state_c[1] = model.init_state(seq_len)
|
state_h[1], state_c[1] = model.init_state(seq_len)
|
||||||
|
|
||||||
blob[0], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 0)
|
blob[0], _ = shark.getSample(seq_len, 0)
|
||||||
blob[1], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 1)
|
blob[1], _ = shark.getSample(seq_len, 1)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
for i in range(len(blob[0])):
|
for i in range(len(blob[0])):
|
||||||
for t in range(2):
|
for t in range(2):
|
||||||
|
Loading…
Reference in New Issue
Block a user