diff --git a/ultimatetictactoe.py b/ultimatetictactoe.py index 7e327d0..571c131 100644 --- a/ultimatetictactoe.py +++ b/ultimatetictactoe.py @@ -180,8 +180,14 @@ class Model(nn.Module): return y if __name__=="__main__": - run = NeuralRuntime(TTTState()) + init = TTTState() + run = NeuralRuntime(init) run.game([0,1], 4) - #trainer = Trainer(TTTState()) - #trainer.train() + + print("[!] Your knowledge will be assimilated!!!") + trainer = Trainer(init) + trainer.train() + trainer.trainFromTerm(run.head) + print('[!] I have become smart. Destroyer of human Ultimate-TicTacToe players!') + trainer.saveToMemoryBank(term) diff --git a/vacuumDecay.py b/vacuumDecay.py index 2b20a2c..a14cc75 100644 --- a/vacuumDecay.py +++ b/vacuumDecay.py @@ -442,7 +442,7 @@ class Runtime(): bots = [None]*self.head.playersNum while self.head.getWinner()==None: self.turn(bots[self.head.curPlayer], calcDepth) - print(self.head.getWinner() + ' won!') + print(str(self.head.getWinner()) + ' won!') self.killWorker() class NeuralRuntime(Runtime): @@ -510,10 +510,11 @@ class Trainer(Runtime): return head = head.parent - def trainModel(self, model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5): + def trainModel(self, model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5, term=None): loss_func = nn.MSELoss() optimizer = optim.Adam(model.parameters(), lr) - term = self.buildDatasetFromModel(model, depth=calcDepth, exacity=exacity) + if term==None: + term = self.buildDatasetFromModel(model, depth=calcDepth, exacity=exacity) print('[*] Conditioning Brain...') for r in range(64): loss_sum = 0 @@ -555,3 +556,15 @@ class Trainer(Runtime): model.load_state_dict(torch.load('brains/uttt.pth')) model.eval() self.main(model, startGen=0) + + def trainFromTerm(self, term): + model = self.rootNode.state.getModel() + model.load_state_dict(torch.load('brains/uttt.pth')) + model.eval() + self.universe.scoreProvider = 'neural' + self.trainModel(model, calcDepth=4, exacity=10, term=term) + + def saveToMemoryBank(self, term): + with open('memoryBank/uttt/'+datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')+'_'+str(int(random.random()*99999))+'.vdm', 'wb') as f: + pickel.dump(term, f) +