Increased layer count of model and refactored model into its own file
This commit is contained in:
		
							parent
							
								
									1b4b6f1a2f
								
							
						
					
					
						commit
						483201c693
					
				| @ -8,6 +8,7 @@ import numpy as np | ||||
| import random | ||||
| 
 | ||||
| import shark | ||||
| from model import Model | ||||
| 
 | ||||
| bs = int(256/8) | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										26
									
								
								model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								model.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,26 @@ | ||||
| import torch | ||||
| from torch import nn | ||||
| from torch import nn, optim | ||||
| from torch.utils.data import DataLoader | ||||
| 
 | ||||
| class Model(nn.Module): | ||||
|     def __init__(self): | ||||
|         super(Model, self).__init__() | ||||
|         self.lstm = nn.LSTM( | ||||
|             input_size=8, | ||||
|             hidden_size=16, | ||||
|             num_layers=5, | ||||
|             dropout=0.01, | ||||
|         ) | ||||
|         self.fc = nn.Linear(16, 1) | ||||
|         self.out = nn.Sigmoid() | ||||
| 
 | ||||
|     def forward(self, x, prev_state): | ||||
|         output, state = self.lstm(x, prev_state) | ||||
|         logits = self.fc(output) | ||||
|         val = self.out(logits) | ||||
|         return val, state | ||||
| 
 | ||||
|     def init_state(self, sequence_length): | ||||
|         return (torch.zeros(5, 1, 16), | ||||
|                 torch.zeros(5, 1, 16)) | ||||
							
								
								
									
										24
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								train.py
									
									
									
									
									
								
							| @ -7,29 +7,7 @@ import random | ||||
| import math | ||||
| 
 | ||||
| import shark | ||||
| 
 | ||||
| class Model(nn.Module): | ||||
|     def __init__(self): | ||||
|         super(Model, self).__init__() | ||||
|         self.lstm = nn.LSTM( | ||||
|             input_size=8, | ||||
|             hidden_size=16, | ||||
|             num_layers=3, | ||||
|             dropout=0.1, | ||||
|         ) | ||||
|         self.fc = nn.Linear(16, 1) | ||||
|         self.out = nn.Sigmoid() | ||||
| 
 | ||||
|     def forward(self, x, prev_state): | ||||
|         output, state = self.lstm(x, prev_state) | ||||
|         logits = self.fc(output) | ||||
|         val = self.out(logits) | ||||
|         #print(str(logits.item())+" > "+str(val.item())) | ||||
|         return val, state | ||||
| 
 | ||||
|     def init_state(self, sequence_length): | ||||
|         return (torch.zeros(3, 1, 16), | ||||
|                 torch.zeros(3, 1, 16)) | ||||
| from model import Model | ||||
| 
 | ||||
| def train(model, seq_len=16*64): | ||||
|     tid = str(int(random.random()*99999)).zfill(5) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user