Changed layout of NN; more options for playing against the ai

This commit is contained in:
Dominik Moritz Roth 2022-05-13 17:07:29 +02:00
parent 16c0913a3b
commit a939fd334c

View File

@ -86,8 +86,8 @@ class TTTState(State):
# sco -= 0.5
# return 1/sco
def getPriority(self, score, cascadeMem):
return self.generation + score*10 - cascadeMem*5 + 100
#def getPriority(self, score, cascadeMem):
# return -cascadeMem*1 + 100
def checkWin(self):
self.update_box_won()
@ -151,45 +151,47 @@ class Model(nn.Module):
self.smol = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=24,
out_channels=16,
kernel_size=(3,3),
stride=3,
padding=0,
),
nn.ReLU()
)
self.comb = nn.Sequential(
nn.Conv1d(
in_channels=24,
out_channels=8,
kernel_size=1,
stride=1,
padding=0,
),
nn.ReLU()
)
#self.comb = nn.Sequential(
# nn.Conv1d(
# in_channels=24,
# out_channels=8,
# kernel_size=1,
# stride=1,
# padding=0,
# ),
# nn.ReLU()
#)
self.out = nn.Sequential(
nn.Linear(9*8, 32),
#nn.Linear(9*8, 32),
#nn.ReLU(),
#nn.Linear(32, 8),
#nn.ReLU(),
nn.Linear(16*9, 12),
nn.ReLU(),
nn.Linear(32, 8),
nn.ReLU(),
nn.Linear(8, 1),
nn.Linear(12, 1),
nn.Sigmoid()
)
def forward(self, x):
x = torch.reshape(x, (1,9,9))
x = self.smol(x)
x = torch.reshape(x, (24,9))
x = self.comb(x)
#x = torch.reshape(x, (24,9))
#x = self.comb(x)
x = torch.reshape(x, (-1,))
y = self.out(x)
return y
def humanVsAi(train=True, remember=False):
def humanVsAi(train=True, remember=False, depth=3, bots=[0,1], noBg=False):
init = TTTState()
run = NeuralRuntime(init)
run.game([0,1], 3)
run.game(bots, depth, bg=not noBg)
if remember or train:
trainer = Trainer(init)
@ -209,7 +211,15 @@ def aiVsAiLoop():
trainer.train()
if __name__=='__main__':
if choose('?', ['Play Against AI','Let AI train'])=='Play Against AI':
options = ['Play Against AI','Play Against AI (AI begins)','Play Against AI (Fast Play)','Playground','Let AI train']
opt = choose('?', options)
if opt == options[0]:
humanVsAi()
elif opt == options[1]:
humanVsAi(bots[1,0])
elif opt == options[2]:
humanVsAi(depth=2,noBg=True)
elif opt == options[3]:
humanVsAi(bots=[None,None])
else:
aiVsAiLoop()