New: MemoryBank
This commit is contained in:
parent
d164a59e31
commit
4912fbc0a4
@ -180,8 +180,14 @@ class Model(nn.Module):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__=="__main__":
|
||||||
run = NeuralRuntime(TTTState())
|
init = TTTState()
|
||||||
|
run = NeuralRuntime(init)
|
||||||
run.game([0,1], 4)
|
run.game([0,1], 4)
|
||||||
|
|
||||||
#trainer = Trainer(TTTState())
|
|
||||||
#trainer.train()
|
print("[!] Your knowledge will be assimilated!!!")
|
||||||
|
trainer = Trainer(init)
|
||||||
|
trainer.train()
|
||||||
|
trainer.trainFromTerm(run.head)
|
||||||
|
print('[!] I have become smart. Destroyer of human Ultimate-TicTacToe players!')
|
||||||
|
trainer.saveToMemoryBank(term)
|
||||||
|
@ -442,7 +442,7 @@ class Runtime():
|
|||||||
bots = [None]*self.head.playersNum
|
bots = [None]*self.head.playersNum
|
||||||
while self.head.getWinner()==None:
|
while self.head.getWinner()==None:
|
||||||
self.turn(bots[self.head.curPlayer], calcDepth)
|
self.turn(bots[self.head.curPlayer], calcDepth)
|
||||||
print(self.head.getWinner() + ' won!')
|
print(str(self.head.getWinner()) + ' won!')
|
||||||
self.killWorker()
|
self.killWorker()
|
||||||
|
|
||||||
class NeuralRuntime(Runtime):
|
class NeuralRuntime(Runtime):
|
||||||
@ -510,9 +510,10 @@ class Trainer(Runtime):
|
|||||||
return
|
return
|
||||||
head = head.parent
|
head = head.parent
|
||||||
|
|
||||||
def trainModel(self, model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5):
|
def trainModel(self, model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5, term=None):
|
||||||
loss_func = nn.MSELoss()
|
loss_func = nn.MSELoss()
|
||||||
optimizer = optim.Adam(model.parameters(), lr)
|
optimizer = optim.Adam(model.parameters(), lr)
|
||||||
|
if term==None:
|
||||||
term = self.buildDatasetFromModel(model, depth=calcDepth, exacity=exacity)
|
term = self.buildDatasetFromModel(model, depth=calcDepth, exacity=exacity)
|
||||||
print('[*] Conditioning Brain...')
|
print('[*] Conditioning Brain...')
|
||||||
for r in range(64):
|
for r in range(64):
|
||||||
@ -555,3 +556,15 @@ class Trainer(Runtime):
|
|||||||
model.load_state_dict(torch.load('brains/uttt.pth'))
|
model.load_state_dict(torch.load('brains/uttt.pth'))
|
||||||
model.eval()
|
model.eval()
|
||||||
self.main(model, startGen=0)
|
self.main(model, startGen=0)
|
||||||
|
|
||||||
|
def trainFromTerm(self, term):
|
||||||
|
model = self.rootNode.state.getModel()
|
||||||
|
model.load_state_dict(torch.load('brains/uttt.pth'))
|
||||||
|
model.eval()
|
||||||
|
self.universe.scoreProvider = 'neural'
|
||||||
|
self.trainModel(model, calcDepth=4, exacity=10, term=term)
|
||||||
|
|
||||||
|
def saveToMemoryBank(self, term):
|
||||||
|
with open('memoryBank/uttt/'+datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')+'_'+str(int(random.random()*99999))+'.vdm', 'wb') as f:
|
||||||
|
pickel.dump(term, f)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user