diff --git a/ultimatetictactoe.py b/ultimatetictactoe.py
index 1ddad80..025b539 100644
--- a/ultimatetictactoe.py
+++ b/ultimatetictactoe.py
@@ -86,8 +86,8 @@ class TTTState(State):
     #        sco -= 0.5
     #    return 1/sco
 
-    def getPriority(self, score, cascadeMem):
-        return self.generation + score*10 - cascadeMem*5 + 100
+    #def getPriority(self, score, cascadeMem):
+    #    return -cascadeMem*1 + 100
 
     def checkWin(self):
         self.update_box_won()
@@ -151,45 +151,47 @@ class Model(nn.Module):
         self.smol = nn.Sequential(
             nn.Conv2d(
                 in_channels=1,
-                out_channels=24,
+                out_channels=16,
                 kernel_size=(3,3),
                 stride=3,
                 padding=0,
             ),
             nn.ReLU()
         )
-        self.comb = nn.Sequential(
-            nn.Conv1d(
-                in_channels=24,
-                out_channels=8,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-            ),
-            nn.ReLU()
-        )
+        #self.comb = nn.Sequential(
+        #    nn.Conv1d(
+        #        in_channels=24,
+        #        out_channels=8,
+        #        kernel_size=1,
+        #        stride=1,
+        #        padding=0,
+        #    ),
+        #    nn.ReLU()
+        #)
         self.out = nn.Sequential(
-            nn.Linear(9*8, 32),
+            #nn.Linear(9*8, 32),
+            #nn.ReLU(),
+            #nn.Linear(32, 8),
+            #nn.ReLU(),
+            nn.Linear(16*9, 12),
             nn.ReLU(),
-            nn.Linear(32, 8),
-            nn.ReLU(),
-            nn.Linear(8, 1),
+            nn.Linear(12, 1),
             nn.Sigmoid()
         )
 
     def forward(self, x):
         x = torch.reshape(x, (1,9,9))
         x = self.smol(x)
-        x = torch.reshape(x, (24,9))
-        x = self.comb(x)
+        #x = torch.reshape(x, (24,9))
+        #x = self.comb(x)
         x = torch.reshape(x, (-1,))
         y = self.out(x)
         return y
 
-def humanVsAi(train=True, remember=False):
+def humanVsAi(train=True, remember=False, depth=3, bots=[0,1], noBg=False):
     init = TTTState()
     run = NeuralRuntime(init)
-    run.game([0,1], 3)
+    run.game(bots, depth, bg=not noBg)
 
     if remember or train:
         trainer = Trainer(init)
@@ -209,7 +211,15 @@ def aiVsAiLoop():
         trainer.train()
 
 if __name__=='__main__':
-    if choose('?', ['Play Against AI','Let AI train'])=='Play Against AI':
+    options = ['Play Against AI','Play Against AI (AI begins)','Play Against AI (Fast Play)','Playground','Let AI train']
+    opt = choose('?', options)
+    if opt == options[0]:
         humanVsAi()
+    elif opt == options[1]:
+        humanVsAi(bots=[1,0])
+    elif opt == options[2]:
+        humanVsAi(depth=2,noBg=True)
+    elif opt == options[3]:
+        humanVsAi(bots=[None,None])
     else:
         aiVsAiLoop()