Refactored saving/loading of NN weights; changed prioritization mechanism of nodes while exploring; added background (bg) computation
parent 5eaf83805f
commit 6967243ae2
BIN  brains/utt.vac       Normal file (binary file not shown)
BIN  brains/uttt.pth.bak  Normal file (binary file not shown)
@@ -1,3 +1,5 @@
+import os
+import io
 import time
 import random
 import threading
@@ -69,7 +71,7 @@ class State(ABC):
         # Lower prioritys get worked on first
         # Higher generations should have higher priority
         # Higher cascadeMemory (more influence on higher-order-scores) should have lower priority
-        return score + self.generation*0.5 - cascadeMemory*0.35
+        return -cascadeMemory + 100

     @abstractmethod
     def checkWin(self):
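For reference, a standalone sketch (not code from this repo) contrasting the old and new priority formulas; lower priority values are worked on first, so the new formula orders nodes purely by cascadeMemory and ignores score and generation.

def old_priority(score, generation, cascade_memory):
    # pre-commit behaviour: mix of score, generation and cascadeMemory
    return score + generation * 0.5 - cascade_memory * 0.35

def new_priority(cascade_memory):
    # post-commit behaviour: only cascadeMemory matters
    return -cascade_memory + 100

print(old_priority(score=0.4, generation=3, cascade_memory=0.8))  # 1.62
print(new_priority(cascade_memory=0.8))                           # 99.2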
@@ -418,7 +420,7 @@ class Runtime():
                return
        raise Exception('No such action avaible...')

-    def turn(self, bot=None, calcDepth=3):
+    def turn(self, bot=None, calcDepth=3, bg=True):
         print(str(self.head))
         if bot==None:
             c = choose('Select action?', ['human', 'bot', 'undo', 'qlen'])
@@ -445,13 +447,15 @@ class Runtime():
             action = self.head.askUserForAction()
         self.performAction(action)

-    def game(self, bots=None, calcDepth=7):
-        self.spawnWorker()
+    def game(self, bots=None, calcDepth=7, bg=True):
+        if bg:
+            self.spawnWorker()
         if bots==None:
             bots = [None]*self.head.playersNum
         while self.head.getWinner()==None:
-            self.turn(bots[self.head.curPlayer], calcDepth)
-        print(['O','X','No one'][head.getWinner()] + ' won!')
-        self.killWorker()
+            self.turn(bots[self.head.curPlayer], calcDepth, bg=True)
+        print(['O','X','No one'][self.head.getWinner()] + ' won!')
+        if bg:
+            self.killWorker()

 class NeuralRuntime(Runtime):
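The bg flag added to turn() and game() makes the background search worker optional. The sketch below is an assumption-level illustration of that pattern (spawn a worker thread before the game loop, stop it afterwards, skip both when bg=False); it is not the project's actual Runtime class.

import threading
import time

class MiniRuntime:
    # illustrative stand-in for Runtime's spawnWorker()/killWorker() pair
    def __init__(self):
        self._stop = threading.Event()
        self._worker = None

    def spawnWorker(self):
        self._worker = threading.Thread(target=self._work, daemon=True)
        self._worker.start()

    def killWorker(self):
        self._stop.set()
        self._worker.join()

    def _work(self):
        while not self._stop.is_set():
            time.sleep(0.05)  # stand-in for expanding search-tree nodes

    def game(self, turns=3, bg=True):
        if bg:
            self.spawnWorker()
        for _ in range(turns):
            time.sleep(0.05)  # stand-in for self.turn(...)
        if bg:
            self.killWorker()

MiniRuntime().game(bg=True)   # with background worker
MiniRuntime().game(bg=False)  # purely foreground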
@@ -570,37 +574,58 @@ class Trainer(Runtime):
             lLoss = loss_sum
         return loss_sum

-    def main(self, model=None, gens=1024, startGen=12):
+    def main(self, model=None, gens=1024, startGen=0):
         newModel = False
         if model==None:
+            print('[!] No brain found. Creating new one...')
             newModel = True
             model = self.rootNode.state.getModel()
         self.universe.scoreProvider = ['neural','naive'][newModel]
+        model.train()
         for gen in range(startGen, startGen+gens):
             print('[#####] Gen '+str(gen)+' training:')
-            loss = self.trainModel(model, calcDepth=min(5,3+int(gen/16)), exacity=int(gen/3+1))
+            loss = self.trainModel(model, calcDepth=min(4,3+int(gen/16)), exacity=int(gen/3+1))
             print('[L] '+str(loss))
             self.universe.scoreProvider = 'neural'
-            self.saveModel(model)
+            self.saveModel(model, gen)

-    def saveModel(self, model):
-        torch.save(model.state_dict(), 'brains/uttt.pth')
+    def saveModel(self, model, gen):
+        dat = model.state_dict()
+        with open(self.getModelFileName(), 'wb') as f:
+            pickle.dump((gen, dat), f)
+
+    def loadModelState(self, model):
+        with open(self.getModelFileName(), 'rb') as f:
+            gen, dat = pickle.load(f)
+        model.load_state_dict(dat)
+        model.eval()
+        return gen
+
+    def loadModel(self):
+        model = self.rootNode.state.getModel()
+        gen = self.loadModelState(model)
+        return model, gen

     def train(self):
-        model = self.rootNode.state.getModel()
-        model.load_state_dict(torch.load('brains/uttt.pth'))
-        model.eval()
-        self.main(model, startGen=0)
+        if os.path.exists(self.getModelFileName()):
+            model, gen = self.loadModel()
+            self.main(model, startGen=gen+1)
+        else:
+            self.main()
+
+    def getModelFileName(self):
+        return 'brains/utt.vac'

     def trainFromTerm(self, term):
         model = self.rootNode.state.getModel()
-        model.load_state_dict(torch.load('brains/uttt.pth'))
+        model.load_state_dict(torch.load('brains/uttt.vac'))
         model.eval()
         self.universe.scoreProvider = 'neural'
         self.trainModel(model, calcDepth=4, exacity=10, term=term)
         self.saveModel(model)

     def saveToMemoryBank(self, term):
+        return
         with open('memoryBank/uttt/'+datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')+'_'+str(int(random.random()*99999))+'.vdm', 'wb') as f:
             pickle.dump(term, f)
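Checkpoints are no longer bare torch.save() weight dumps: saveModel() now pickles a (generation, state_dict) tuple into the file returned by getModelFileName(), and train() resumes at gen+1 via loadModel(). Below is a minimal round-trip sketch under that assumption; the toy model and file name are placeholders, not the project's.

import pickle
import torch

model = torch.nn.Linear(4, 2)      # placeholder for rootNode.state.getModel()
checkpoint_path = 'example.vac'    # the commit itself uses 'brains/utt.vac'

# save side, analogous to saveModel(model, gen)
gen = 7
with open(checkpoint_path, 'wb') as f:
    pickle.dump((gen, model.state_dict()), f)

# load side, analogous to loadModelState(model) returning the stored gen
with open(checkpoint_path, 'rb') as f:
    gen, dat = pickle.load(f)
model.load_state_dict(dat)
model.eval()
print('resuming training at gen', gen + 1)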