Changed layout of NN; more options for playing against the ai
This commit is contained in:
		
							parent
							
								
									16c0913a3b
								
							
						
					
					
						commit
						a939fd334c
					
				@ -86,8 +86,8 @@ class TTTState(State):
 | 
				
			|||||||
    #            sco -= 0.5
 | 
					    #            sco -= 0.5
 | 
				
			||||||
    #    return 1/sco
 | 
					    #    return 1/sco
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def getPriority(self, score, cascadeMem):
 | 
					    #def getPriority(self, score, cascadeMem):
 | 
				
			||||||
        return self.generation + score*10 - cascadeMem*5 + 100
 | 
					    #    return -cascadeMem*1 + 100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def checkWin(self):
 | 
					    def checkWin(self):
 | 
				
			||||||
        self.update_box_won()
 | 
					        self.update_box_won()
 | 
				
			||||||
@ -151,45 +151,47 @@ class Model(nn.Module):
 | 
				
			|||||||
        self.smol = nn.Sequential(
 | 
					        self.smol = nn.Sequential(
 | 
				
			||||||
            nn.Conv2d(
 | 
					            nn.Conv2d(
 | 
				
			||||||
                in_channels=1,
 | 
					                in_channels=1,
 | 
				
			||||||
                out_channels=24,
 | 
					                out_channels=16,
 | 
				
			||||||
                kernel_size=(3,3),
 | 
					                kernel_size=(3,3),
 | 
				
			||||||
                stride=3,
 | 
					                stride=3,
 | 
				
			||||||
                padding=0,
 | 
					                padding=0,
 | 
				
			||||||
            ),
 | 
					            ),
 | 
				
			||||||
            nn.ReLU()
 | 
					            nn.ReLU()
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.comb = nn.Sequential(
 | 
					        #self.comb = nn.Sequential(
 | 
				
			||||||
            nn.Conv1d(
 | 
					        #    nn.Conv1d(
 | 
				
			||||||
                in_channels=24,
 | 
					        #        in_channels=24,
 | 
				
			||||||
                out_channels=8,
 | 
					        #        out_channels=8,
 | 
				
			||||||
                kernel_size=1,
 | 
					        #        kernel_size=1,
 | 
				
			||||||
                stride=1,
 | 
					        #        stride=1,
 | 
				
			||||||
                padding=0,
 | 
					        #        padding=0,
 | 
				
			||||||
            ),
 | 
					        #    ),
 | 
				
			||||||
            nn.ReLU()
 | 
					        #    nn.ReLU()
 | 
				
			||||||
        )
 | 
					        #)
 | 
				
			||||||
        self.out = nn.Sequential(
 | 
					        self.out = nn.Sequential(
 | 
				
			||||||
            nn.Linear(9*8, 32),
 | 
					            #nn.Linear(9*8, 32),
 | 
				
			||||||
 | 
					            #nn.ReLU(),
 | 
				
			||||||
 | 
					            #nn.Linear(32, 8),
 | 
				
			||||||
 | 
					            #nn.ReLU(),
 | 
				
			||||||
 | 
					            nn.Linear(16*9, 12),
 | 
				
			||||||
            nn.ReLU(),
 | 
					            nn.ReLU(),
 | 
				
			||||||
            nn.Linear(32, 8),
 | 
					            nn.Linear(12, 1),
 | 
				
			||||||
            nn.ReLU(),
 | 
					 | 
				
			||||||
            nn.Linear(8, 1),
 | 
					 | 
				
			||||||
            nn.Sigmoid()
 | 
					            nn.Sigmoid()
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, x):
 | 
					    def forward(self, x):
 | 
				
			||||||
        x = torch.reshape(x, (1,9,9))
 | 
					        x = torch.reshape(x, (1,9,9))
 | 
				
			||||||
        x = self.smol(x)
 | 
					        x = self.smol(x)
 | 
				
			||||||
        x = torch.reshape(x, (24,9))
 | 
					        #x = torch.reshape(x, (24,9))
 | 
				
			||||||
        x = self.comb(x)
 | 
					        #x = self.comb(x)
 | 
				
			||||||
        x = torch.reshape(x, (-1,))
 | 
					        x = torch.reshape(x, (-1,))
 | 
				
			||||||
        y = self.out(x)
 | 
					        y = self.out(x)
 | 
				
			||||||
        return y
 | 
					        return y
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def humanVsAi(train=True, remember=False):
 | 
					def humanVsAi(train=True, remember=False, depth=3, bots=[0,1], noBg=False):
 | 
				
			||||||
    init = TTTState()
 | 
					    init = TTTState()
 | 
				
			||||||
    run = NeuralRuntime(init)
 | 
					    run = NeuralRuntime(init)
 | 
				
			||||||
    run.game([0,1], 3)
 | 
					    run.game(bots, depth, bg=not noBg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if remember or train:
 | 
					    if remember or train:
 | 
				
			||||||
        trainer = Trainer(init)
 | 
					        trainer = Trainer(init)
 | 
				
			||||||
@ -209,7 +211,15 @@ def aiVsAiLoop():
 | 
				
			|||||||
    trainer.train()
 | 
					    trainer.train()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__=='__main__':
 | 
					if __name__=='__main__':
 | 
				
			||||||
    if choose('?', ['Play Against AI','Let AI train'])=='Play Against AI':
 | 
					    options = ['Play Against AI','Play Against AI (AI begins)','Play Against AI (Fast Play)','Playground','Let AI train']
 | 
				
			||||||
 | 
					    opt = choose('?', options)
 | 
				
			||||||
 | 
					    if opt == options[0]:
 | 
				
			||||||
        humanVsAi()
 | 
					        humanVsAi()
 | 
				
			||||||
 | 
					    elif opt == options[1]:
 | 
				
			||||||
 | 
					        humanVsAi(bots[1,0])
 | 
				
			||||||
 | 
					    elif opt == options[2]:
 | 
				
			||||||
 | 
					        humanVsAi(depth=2,noBg=True)
 | 
				
			||||||
 | 
					    elif opt == options[3]:
 | 
				
			||||||
 | 
					        humanVsAi(bots=[None,None])
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        aiVsAiLoop()
 | 
					        aiVsAiLoop()
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user