Smashed bugs
This commit is contained in:
parent
c4ff7832ef
commit
bc64d679bf
BIN
brains/uttt.pth
BIN
brains/uttt.pth
Binary file not shown.
@ -93,9 +93,14 @@ class TTTState(State):
|
|||||||
self.update_box_won()
|
self.update_box_won()
|
||||||
game_won = self.check_small_box(self.box_won)
|
game_won = self.check_small_box(self.box_won)
|
||||||
if game_won == '.':
|
if game_won == '.':
|
||||||
|
if self.checkDraw():
|
||||||
|
return -1
|
||||||
return None
|
return None
|
||||||
return game_won == 'X'
|
return game_won == 'X'
|
||||||
|
|
||||||
|
def checkDraw(self):
|
||||||
|
return len(self.getAvaibleActions())==0
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
state = self.board
|
state = self.board
|
||||||
acts = list(self.getAvaibleActions())
|
acts = list(self.getAvaibleActions())
|
||||||
|
@ -187,8 +187,6 @@ class Node():
|
|||||||
|
|
||||||
def _expand(self):
|
def _expand(self):
|
||||||
self._childs = []
|
self._childs = []
|
||||||
if self.getWinner()!=None:
|
|
||||||
return
|
|
||||||
actions = self.state.getAvaibleActions()
|
actions = self.state.getAvaibleActions()
|
||||||
for action in actions:
|
for action in actions:
|
||||||
newNode = Node(self.state.mutate(action), self.universe, self, action)
|
newNode = Node(self.state.mutate(action), self.universe, self, action)
|
||||||
@ -287,12 +285,12 @@ class Node():
|
|||||||
self._calcScore(p)
|
self._calcScore(p)
|
||||||
|
|
||||||
def _calcScore(self, player):
|
def _calcScore(self, player):
|
||||||
winner = self.getWinner()
|
winner = self._getWinner()
|
||||||
if winner!=None:
|
if winner!=None:
|
||||||
if winner==-1:
|
if winner==player:
|
||||||
self._scores[player] = 2/3 # draw
|
|
||||||
elif winner==player:
|
|
||||||
self._scores[player] = 0.0
|
self._scores[player] = 0.0
|
||||||
|
elif winner==-1:
|
||||||
|
self._scores[player] = 2/3
|
||||||
else:
|
else:
|
||||||
self._scores[player] = 1.0
|
self._scores[player] = 1.0
|
||||||
return
|
return
|
||||||
@ -321,17 +319,20 @@ class Node():
|
|||||||
def curPlayer(self):
|
def curPlayer(self):
|
||||||
return self.state.curPlayer
|
return self.state.curPlayer
|
||||||
|
|
||||||
|
def _getWinner(self):
|
||||||
|
return self.state.checkWin()
|
||||||
|
|
||||||
def getWinner(self):
|
def getWinner(self):
|
||||||
if len(self.childs)==0:
|
if len(self.childs)==0:
|
||||||
return -1
|
return -1
|
||||||
return self.state.checkWin()
|
return self._getWinner()
|
||||||
|
|
||||||
def _activateEdge(self, dist=0):
|
def _activateEdge(self, dist=0):
|
||||||
if not self.strongScoresAvaible():
|
if not self.strongScoresAvaible():
|
||||||
self.universe.newOpen(self)
|
self.universe.newOpen(self)
|
||||||
else:
|
else:
|
||||||
for c in self.childs:
|
for c in self.childs:
|
||||||
if c._cascadeMemory > 0.001*dist or random.random()<0.01:
|
if c._cascadeMemory > 0.001*(dist-2) or random.random()<0.01:
|
||||||
c._activateEdge(dist=dist+1)
|
c._activateEdge(dist=dist+1)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@ -463,6 +464,7 @@ class NeuralRuntime(Runtime):
|
|||||||
class Trainer(Runtime):
|
class Trainer(Runtime):
|
||||||
def __init__(self, initState):
|
def __init__(self, initState):
|
||||||
self.universe = Universe()
|
self.universe = Universe()
|
||||||
|
#self.universe = QueueingUniverse()
|
||||||
self.rootNode = Node(initState, universe = self.universe)
|
self.rootNode = Node(initState, universe = self.universe)
|
||||||
self.terminal = None
|
self.terminal = None
|
||||||
|
|
||||||
@ -470,10 +472,12 @@ class Trainer(Runtime):
|
|||||||
print('[*] Building Timeline')
|
print('[*] Building Timeline')
|
||||||
term = self.linearPlay(model, calcDepth=depth, exacity=exacity)
|
term = self.linearPlay(model, calcDepth=depth, exacity=exacity)
|
||||||
if refining:
|
if refining:
|
||||||
print('[*] Refining Timeline')
|
print('[*] Refining Timeline (exploring alternative endings)')
|
||||||
self.fanOut(term, depth=depth+2)
|
self.fanOut(term, depth=depth+2)
|
||||||
self.fanOut(term.parent, depth=depth+2)
|
self.fanOut(term.parent, depth=depth+2)
|
||||||
self.fanOut(term.parent.parent, depth=depth+2)
|
self.fanOut(term.parent.parent, depth=depth+2)
|
||||||
|
#print('[*] Refining Timeline (exploring uncertain regions)')
|
||||||
|
#self.timelineExpandUncertain(term, 20)
|
||||||
return term
|
return term
|
||||||
|
|
||||||
def fanOut(self, head, depth=4):
|
def fanOut(self, head, depth=4):
|
||||||
@ -514,6 +518,15 @@ class Trainer(Runtime):
|
|||||||
return
|
return
|
||||||
head = head.parent
|
head = head.parent
|
||||||
|
|
||||||
|
def timelineExpandUncertain(self, term, secs):
|
||||||
|
return
|
||||||
|
self.rootNode.universe.clearPQ()
|
||||||
|
self.rootNode.universe.activateEdge(rootNode)
|
||||||
|
self.spawnWorker()
|
||||||
|
time.sleep(secs)
|
||||||
|
self.rootNode.universe.clearPQ()
|
||||||
|
self.killWorker()
|
||||||
|
|
||||||
def trainModel(self, model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5, term=None):
|
def trainModel(self, model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5, term=None):
|
||||||
loss_func = nn.MSELoss()
|
loss_func = nn.MSELoss()
|
||||||
optimizer = optim.Adam(model.parameters(), lr)
|
optimizer = optim.Adam(model.parameters(), lr)
|
||||||
|
Loading…
Reference in New Issue
Block a user