1:0 against kenria Booyaa

This commit is contained in:
Dominik Moritz Roth 2022-04-15 11:18:34 +02:00
parent 6cc2d84519
commit d164a59e31
3 changed files with 44 additions and 27 deletions

Binary file not shown.

View File

@ -141,40 +141,47 @@ class Model(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.smolChan = 12
self.compChan = 7
self.smol = nn.Sequential( self.smol = nn.Sequential(
nn.Conv2d( nn.Conv2d(
in_channels=1, in_channels=1,
out_channels=self.smolChan, out_channels=24,
kernel_size=(3,3), kernel_size=(3,3),
stride=3, stride=3,
padding=0, padding=0,
), ),
nn.ReLU() nn.ReLU()
) )
self.big = nn.Sequential( self.comb = nn.Sequential(
nn.Linear(self.smolChan*9, self.compChan), nn.Conv1d(
#nn.ReLU(), in_channels=24,
#nn.Linear(self.compChan, 1), out_channels=8,
kernel_size=1,
stride=1,
padding=0,
),
nn.ReLU()
)
self.out = nn.Sequential(
nn.Linear(9*8, 32),
nn.ReLU(), nn.ReLU(),
nn.Linear(self.compChan, 3), nn.Linear(32, 8),
nn.ReLU(), nn.ReLU(),
nn.Linear(3, 1), nn.Linear(8, 1),
nn.Sigmoid() nn.Sigmoid()
) )
def forward(self, x): def forward(self, x):
x = torch.reshape(x, (1,9,9)) x = torch.reshape(x, (1,9,9))
x = self.smol(x) x = self.smol(x)
x = torch.reshape(x, (self.smolChan*9,)) x = torch.reshape(x, (24,9))
y = self.big(x) x = self.comb(x)
x = torch.reshape(x, (-1,))
y = self.out(x)
return y return y
if __name__=="__main__": if __name__=="__main__":
run = NeuralRuntime(TTTState()) run = NeuralRuntime(TTTState())
run.game(None, 4) run.game([0,1], 4)
trainer = Trainer(TTTState()) #trainer = Trainer(TTTState())
trainer.train() #trainer.train()

View File

@ -11,6 +11,7 @@ from threading import Event
from queue import PriorityQueue, Empty from queue import PriorityQueue, Empty
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
import random
class Action(): class Action():
# Should hold the data representing an action # Should hold the data representing an action
@ -388,6 +389,7 @@ class Runtime():
def __init__(self, initState): def __init__(self, initState):
universe = QueueingUniverse() universe = QueueingUniverse()
self.head = Node(initState, universe = universe) self.head = Node(initState, universe = universe)
_ = self.head.childs
universe.newOpen(self.head) universe.newOpen(self.head)
def spawnWorker(self): def spawnWorker(self):
@ -460,9 +462,9 @@ class Trainer(Runtime):
self.rootNode = Node(initState, universe = self.universe) self.rootNode = Node(initState, universe = self.universe)
self.terminal = None self.terminal = None
def buildDatasetFromModel(self, model, depth=4, refining=True): def buildDatasetFromModel(self, model, depth=4, refining=True, exacity=5):
print('[*] Building Timeline') print('[*] Building Timeline')
term = self.linearPlay(model, calcDepth=depth) term = self.linearPlay(model, calcDepth=depth, exacity=exacity)
if refining: if refining:
print('[*] Refining Timeline') print('[*] Refining Timeline')
self.fanOut(term, depth=depth+1) self.fanOut(term, depth=depth+1)
@ -475,7 +477,7 @@ class Trainer(Runtime):
head = head.parent head = head.parent
head.forceStrong(depth) head.forceStrong(depth)
def linearPlay(self, model, calcDepth=7, verbose=True): def linearPlay(self, model, calcDepth=7, exacity=5, verbose=True):
head = self.rootNode head = self.rootNode
self.universe.model = model self.universe.model = model
while head.getWinner()==None: while head.getWinner()==None:
@ -490,7 +492,10 @@ class Trainer(Runtime):
for c in head.childs: for c in head.childs:
opts.append((c, c.getStrongFor(head.curPlayer))) opts.append((c, c.getStrongFor(head.curPlayer)))
opts.sort(key=lambda x: x[1]) opts.sort(key=lambda x: x[1])
ind = int(pow(random.random(),5)*(len(opts)-1)) if exacity >= 10:
ind = 0
else:
ind = int(pow(random.random(),exacity)*(len(opts)-1))
head = opts[ind][0] head = opts[ind][0]
print('') print('')
return head return head
@ -499,16 +504,20 @@ class Trainer(Runtime):
head = term head = term
while True: while True:
yield head yield head
if len(head.childs):
yield random.choice(head.childs)
if head.parent == None: if head.parent == None:
return return
head = head.parent head = head.parent
def trainModel(self, model, lr=0.01, cut=0.01, calcDepth=4): def trainModel(self, model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5):
loss_func = nn.MSELoss() loss_func = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr) optimizer = optim.Adam(model.parameters(), lr)
term = self.buildDatasetFromModel(model, depth=calcDepth) term = self.buildDatasetFromModel(model, depth=calcDepth, exacity=exacity)
print('[*] Conditioning Brain...')
for r in range(64): for r in range(64):
loss_sum = 0 loss_sum = 0
lLoss = 0
zeroLen = 0 zeroLen = 0
for i, node in enumerate(self.timelineIter(term)): for i, node in enumerate(self.timelineIter(term)):
for p in range(self.rootNode.playersNum): for p in range(self.rootNode.playersNum):
@ -524,19 +533,20 @@ class Trainer(Runtime):
zeroLen+=1 zeroLen+=1
if zeroLen == 5: if zeroLen == 5:
break break
print(loss_sum/i) #print(loss_sum/i)
if loss_sum/i < cut: if r > 16 and (loss_sum/i < cut or lLoss == loss_sum):
return return
lLoss = loss_sum
def main(self, model=None, gens=64): def main(self, model=None, gens=1024, startGen=12):
newModel = False newModel = False
if model==None: if model==None:
newModel = True newModel = True
model = self.rootNode.state.getModel() model = self.rootNode.state.getModel()
self.universe.scoreProvider = ['neural','naive'][newModel] self.universe.scoreProvider = ['neural','naive'][newModel]
for gen in range(gens): for gen in range(startGen, startGen+gens):
print('[#####] Gen '+str(gen)+' training:') print('[#####] Gen '+str(gen)+' training:')
self.trainModel(model, calcDepth=3) self.trainModel(model, calcDepth=min(5,3+int(gen/16)), exacity=int(gen/3+1))
self.universe.scoreProvider = 'neural' self.universe.scoreProvider = 'neural'
torch.save(model.state_dict(), 'brains/uttt.pth') torch.save(model.state_dict(), 'brains/uttt.pth')
@ -544,4 +554,4 @@ class Trainer(Runtime):
model = self.rootNode.state.getModel() model = self.rootNode.state.getModel()
model.load_state_dict(torch.load('brains/uttt.pth')) model.load_state_dict(torch.load('brains/uttt.pth'))
model.eval() model.eval()
self.main(model) self.main(model, startGen=0)