This commit is contained in:
Dominik Moritz Roth 2022-04-15 14:30:52 +02:00
parent e002375651
commit 5d808d77c9
2 changed files with 24 additions and 9 deletions

View File

@ -179,18 +179,29 @@ class Model(nn.Module):
y = self.out(x)
return y
if __name__=="__main__":
def humanVsAi(train=True, remember=True):
init = TTTState()
run = NeuralRuntime(init)
run.game([0,1], 4)
print("[!] Your knowledge will be assimilated!!! Please stand by....")
trainer = Trainer(init)
trainer.train()
trainer.trainFromTerm(run.head)
if train:
print("[!] Your knowledge will be assimilated!!! Please stand by....")
trainer = Trainer(init)
trainer.trainFromTerm(run.head)
print('[!] I have become smart! Destroyer of human Ultimate-TicTacToe players! (Neuristic update completed)')
trainer.saveToMemoryBank(term)
print('[!] Your cognitive and strategic destinctiveness was added to my own! (Game inserted into memoryBank)')
if remember:
trainer.saveToMemoryBank(term)
print('[!] Your cognitive and strategic destinctiveness was added to my own! (Game inserted into memoryBank)')
print('[!] This marks the beginning of the end of humankind!')
print('[i] Thanks for playing! Goodbye...')
def aiVsAiLoop():
init = TTTState()
trainer = Trainer(init)
trainer.train()
if __name__=='__main__':
if choose('?', ['Play Against AI','Let AI train'])=='Play Against AI':
humanVsAi()
else:
aiVsAiLoop()

View File

@ -549,7 +549,10 @@ class Trainer(Runtime):
print('[#####] Gen '+str(gen)+' training:')
self.trainModel(model, calcDepth=min(5,3+int(gen/16)), exacity=int(gen/3+1))
self.universe.scoreProvider = 'neural'
torch.save(model.state_dict(), 'brains/uttt.pth')
self.saveModel(model)
def saveModel(self, model):
torch.save(model.state_dict(), 'brains/uttt.pth')
def train(self):
model = self.rootNode.state.getModel()
@ -563,6 +566,7 @@ class Trainer(Runtime):
model.eval()
self.universe.scoreProvider = 'neural'
self.trainModel(model, calcDepth=4, exacity=10, term=term)
self.saveModel(model)
def saveToMemoryBank(self, term):
with open('memoryBank/uttt/'+datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')+'_'+str(int(random.random()*99999))+'.vdm', 'wb') as f: