Smol changes on parameters
This commit is contained in:
		
							parent
							
								
									d7c5d57d93
								
							
						
					
					
						commit
						4f250f61c3
					
				
							
								
								
									
										19
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								train.py
									
									
									
									
									
								
							@ -9,7 +9,7 @@ import math
 | 
				
			|||||||
import shark
 | 
					import shark
 | 
				
			||||||
from model import Model
 | 
					from model import Model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def train(model, seq_len=16*128): # 0.25KiB
 | 
					def train(model, seq_len=16*256): # 0.5KiB
 | 
				
			||||||
    tid = str(int(random.random()*99999)).zfill(5)
 | 
					    tid = str(int(random.random()*99999)).zfill(5)
 | 
				
			||||||
    print("[i] I am "+str(tid))
 | 
					    print("[i] I am "+str(tid))
 | 
				
			||||||
    ltLoss = 0.75
 | 
					    ltLoss = 0.75
 | 
				
			||||||
@ -17,12 +17,14 @@ def train(model, seq_len=16*128): # 0.25KiB
 | 
				
			|||||||
    model.train()
 | 
					    model.train()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    criterion = nn.BCELoss()
 | 
					    criterion = nn.BCELoss()
 | 
				
			||||||
    optimizer = optim.Adam(model.parameters(), lr=0.001)
 | 
					    optimizer = optim.Adam(model.parameters(), lr=0.0001)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    state_h = [None,None]
 | 
					    state_h = [None,None]
 | 
				
			||||||
    state_c = [None,None]
 | 
					    state_c = [None,None]
 | 
				
			||||||
    blob = [None,None]
 | 
					    blob = [None,None]
 | 
				
			||||||
    correct = [None,None]
 | 
					    correct = [None,None]
 | 
				
			||||||
 | 
					    err = [None,None]
 | 
				
			||||||
 | 
					    ltErr = 0.5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for epoch in range(1024):
 | 
					    for epoch in range(1024):
 | 
				
			||||||
        state_h[0], state_c[0] = model.init_state(seq_len)
 | 
					        state_h[0], state_c[0] = model.init_state(seq_len)
 | 
				
			||||||
@ -44,17 +46,16 @@ def train(model, seq_len=16*128): # 0.25KiB
 | 
				
			|||||||
                optimizer.step()
 | 
					                optimizer.step()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                correct[t] = round(y_pred.item()) == t
 | 
					                correct[t] = round(y_pred.item()) == t
 | 
				
			||||||
 | 
					                err[t] = abs(t - y_pred.item())
 | 
				
			||||||
            ltLoss = ltLoss*0.9 + 0.1*loss.item()
 | 
					            ltLoss = ltLoss*0.9 + 0.1*loss.item()
 | 
				
			||||||
 | 
					            ltErr = ltErr*0.99 + (err[0] + err[1])*0.005
 | 
				
			||||||
        lltLoss = lltLoss*0.9 + 0.1*ltLoss
 | 
					        lltLoss = lltLoss*0.9 + 0.1*ltLoss
 | 
				
			||||||
        print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" })
 | 
					        print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(100-(err[0]+err[1])*50))+"%" })
 | 
				
			||||||
        if epoch % 8 == 0:
 | 
					        torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
 | 
				
			||||||
            torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
 | 
					    if 0.45 < ltErr < 0.55:
 | 
				
			||||||
    if lltLoss > 0.49:
 | 
					 | 
				
			||||||
        print("[~] My emperor! I've failed! A BARREL ROLL!")
 | 
					        print("[~] My emperor! I've failed! A BARREL ROLL!")
 | 
				
			||||||
    elif lltLoss < 0.45:
 | 
					 | 
				
			||||||
        print("[~] Booyaaa!!!!")
 | 
					 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        print("[~] Meh...")
 | 
					        print("[~] Booyaaa!!!!")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
model = Model()
 | 
					model = Model()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user