New and improved stuff

Dominik Moritz Roth 2022-04-16 15:11:59 +02:00
parent 8353dc13d0
commit f15a707f02
2 changed files with 22 additions and 13 deletions

Binary file not shown.


@@ -451,7 +451,7 @@ class Runtime():
         bots = [None]*self.head.playersNum
         while self.head.getWinner()==None:
             self.turn(bots[self.head.curPlayer], calcDepth)
-        print(str(self.head.getWinner()) + ' won!')
+        print(['O','X','No one'][self.head.getWinner()] + ' won!')
         self.killWorker()

 class NeuralRuntime(Runtime):
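
The new message maps the winner to a readable name. A minimal sketch of the indexing trick, assuming getWinner() returns the winning player's index (0 or 1) once the loop exits and -1 for a draw; Python's negative indexing then sends the draw to the last entry:

    outcomes = ['O', 'X', 'No one']
    for winner in (0, 1, -1):            # player O, player X, draw (assumed encoding)
        print(outcomes[winner] + ' won!')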
@@ -473,7 +473,7 @@ class Trainer(Runtime):
         self.rootNode = self.head
         self.terminal = None

-    def buildDatasetFromModel(self, model, depth=4, refining=True, fanOut=[16,16,8,8,6,6,5,4], uncertainSec=15, exacity=5):
+    def buildDatasetFromModel(self, model, depth=4, refining=True, fanOut=[5,5,5,5,4,4,4,4], uncertainSec=15, exacity=5):
         print('[*] Building Timeline')
         term = self.linearPlay(model, calcDepth=depth, exacity=exacity)
         if refining:
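
The shrunken fanOut default trades dataset breadth for speed. One general caveat, not specific to this change: a mutable default like fanOut=[5,5,5,5,4,4,4,4] is created once and shared across calls, so the function body must never mutate it. A stdlib-only illustration of the pitfall:

    def f(xs=[1]):
        xs.append(1)      # mutates the shared default list
        return len(xs)

    print(f(), f())       # 2 3 -- the default grew between calls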
@@ -482,13 +482,16 @@ class Trainer(Runtime):
             for d in fanOut:
                 cur = cur.parent
                 cur.forceStrong(d)
+                print('.', end='', flush=True)
+            print('')
             print('[*] Refining Timeline (exploring uncertain regions)')
             self.timelineExpandUncertain(term, uncertainSec)
         return term

-    def linearPlay(self, model, calcDepth=7, exacity=5, verbose=True):
+    def linearPlay(self, model, calcDepth=7, exacity=5, verbose=False, firstNRandom=2):
         head = self.rootNode
         self.universe.model = model
+        self.spawnWorker()
         while head.getWinner()==None:
             if verbose:
                 print(head)
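
linearPlay is now quiet by default, manages its own worker (spawned here, killed after the game loop in the next hunk), and plays the first firstNRandom plies uniformly at random to diversify openings. A hypothetical call, assuming an already-constructed Trainer and a compatible model:

    term = trainer.linearPlay(model)                  # two random opening plies, quiet
    term = trainer.linearPlay(model, firstNRandom=0,
                              verbose=True)           # fully model-guided, prints positions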
@@ -500,15 +503,20 @@ class Trainer(Runtime):
                     break
             for c in head.childs:
                 opts.append((c, c.getStrongFor(head.curPlayer)))
-            opts.sort(key=lambda x: x[1])
-            if exacity >= 10:
-                ind = 0
-            else:
-                ind = int(pow(random.random(),exacity)*(len(opts)-1))
+            if firstNRandom:
+                firstNRandom-=1
+                ind = int(random.random()*len(opts))
+            else:
+                opts.sort(key=lambda x: x[1])
+                if exacity >= 10:
+                    ind = 0
+                else:
+                    ind = int(pow(random.random(),exacity)*(len(opts)-1))
             head = opts[ind][0]
+        self.killWorker()
         if verbose:
             print(head)
-        print(' => '+['O','X','No one '][self.head.getWinner()] + ' won!')
+        print(' => '+['O','X','No one'][head.getWinner()] + ' won!')
         return head

     def timelineIter(self, term):
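
After the random opening, move choice falls back to the exacity-biased pick: since random.random() is uniform on [0, 1), pow(random.random(), exacity) concentrates near 0 as exacity grows, so the pick favors index 0, the same move the fully greedy exacity >= 10 branch takes. Note that int(pow(...)*(len(opts)-1)) can never yield the final index, so the lowest-ranked option is never chosen. A stdlib-only look at the resulting distribution (not repo code):

    import random
    from collections import Counter

    def pick(n_opts, exacity):
        if exacity >= 10:
            return 0
        return int(pow(random.random(), exacity) * (n_opts - 1))

    counts = Counter(pick(8, exacity=5) for _ in range(10_000))
    print(sorted(counts.items()))   # index 0 dominates; the tail tapers off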
@@ -525,16 +533,19 @@ class Trainer(Runtime):
         self.rootNode.universe.clearPQ()
         self.rootNode.universe.activateEdge(self.rootNode)
         self.spawnWorker()
-        time.sleep(secs)
+        for s in range(secs):
+            time.sleep(1)
+            print('.', end='', flush=True)
         self.rootNode.universe.clearPQ()
         self.killWorker()
+        print('')

     def trainModel(self, model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5, term=None):
         loss_func = nn.MSELoss()
         optimizer = optim.Adam(model.parameters(), lr)
         if term==None:
             term = self.buildDatasetFromModel(model, depth=calcDepth, exacity=exacity)
-        print('[*] Conditioning Brain...')
+        print('[*] Conditioning Brain')
         for r in range(64):
             loss_sum = 0
             lLoss = 0
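
timelineExpandUncertain now swaps the single sleep for a per-second loop so long waits show a heartbeat. The same pattern as a standalone sketch:

    import time

    def wait_with_progress(secs):
        for _ in range(secs):
            time.sleep(1)
            print('.', end='', flush=True)
        print('')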
@@ -542,10 +553,8 @@ class Trainer(Runtime):
             for i, node in enumerate(self.timelineIter(term)):
                 for p in range(self.rootNode.playersNum):
                     inp = node.state.getTensor(player=p)
-                    gol = torch.tensor(node.getStrongFor(p), dtype=torch.float)
+                    gol = torch.tensor([node.getStrongFor(p)], dtype=torch.float)
                     out = model(inp)
-                    if not out:
-                        continue
                     loss = loss_func(out, gol)
                     optimizer.zero_grad()
                     loss.backward()
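
Two training fixes land in this hunk. Wrapping the target value in a list gives gol shape (1,) instead of a 0-d scalar, matching a model that emits a single value (assumed from context) and avoiding MSELoss broadcasting. The removed `if not out` guard was itself a latent bug: the truthiness of a one-element tensor is its value, so a prediction of exactly 0.0 would have been silently skipped. A standalone PyTorch check:

    import torch

    gol_scalar = torch.tensor(0.5, dtype=torch.float)    # shape torch.Size([])
    gol_vector = torch.tensor([0.5], dtype=torch.float)  # shape torch.Size([1])
    print(gol_scalar.shape, gol_vector.shape)

    out = torch.tensor([0.0])
    print(not out)   # True -- the old guard dropped exactly-zero predictions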