Blub
This commit is contained in:
parent
a46557a635
commit
5ba277a2aa
23
tictactoe.py
23
tictactoe.py
@ -2,24 +2,23 @@ from vacuumDecay import *
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
class TTTState(State):
|
class TTTState(State):
|
||||||
def __init__(self, turn=0, generation=0, playersNum=2, board=None):
|
def __init__(self, curPlayer=0, generation=0, playersNum=2, board=None):
|
||||||
if type(board) == type(None):
|
if type(board) == type(None):
|
||||||
board = np.array([None]*9)
|
board = np.array([None]*9)
|
||||||
self.turn = turn
|
self.curPlayer = curPlayer
|
||||||
self.generation = generation
|
self.generation = generation
|
||||||
self.playersNum = playersNum
|
self.playersNum = playersNum
|
||||||
self.board = board
|
self.board = board
|
||||||
self.score = self.getScore()
|
|
||||||
|
|
||||||
def mutate(self, action):
|
def mutate(self, action):
|
||||||
newBoard = np.copy(self.board)
|
newBoard = np.copy(self.board)
|
||||||
newBoard[action.data] = self.turn
|
newBoard[action.data] = self.curPlayer
|
||||||
return TTTState(turn=(self.turn+1)%self.playersNum, playersNum=self.playersNum, board=newBoard)
|
return TTTState(curPlayer=(self.curPlayer+1)%self.playersNum, playersNum=self.playersNum, board=newBoard)
|
||||||
|
|
||||||
def getAvaibleActions(self):
|
def getAvaibleActions(self):
|
||||||
for i in range(9):
|
for i in range(9):
|
||||||
if self.board[i]==None:
|
if self.board[i]==None:
|
||||||
yield Action(self.turn, i)
|
yield Action(self.curPlayer, i)
|
||||||
|
|
||||||
def checkWin(self):
|
def checkWin(self):
|
||||||
s = self.board
|
s = self.board
|
||||||
@ -49,13 +48,13 @@ class TTTState(State):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def getModel():
|
def getModel():
|
||||||
return torch.nn.Sequential(
|
return torch.nn.Sequential(
|
||||||
torch.nn.Linear(10, 10)
|
torch.nn.Linear(10, 10),
|
||||||
torch.nn.ReLu()
|
torch.nn.ReLu(),
|
||||||
torch.nn.Linear(10, 3)
|
torch.nn.Linear(10, 3),
|
||||||
torch.nn.Sigmoid()
|
torch.nn.Sigmoid(),
|
||||||
torch.nn.Linear(3,1)
|
torch.nn.Linear(3,1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__=="__main__":
|
||||||
vd = VacuumDecay(TTTState())
|
run = Runtime(TTTState())
|
||||||
vd.weakPlay()
|
run.game()
|
||||||
|
498
vacuumDecay.py
498
vacuumDecay.py
@ -28,33 +28,37 @@ class Action():
|
|||||||
# should start with < and end with >
|
# should start with < and end with >
|
||||||
return "<P"+str(self.player)+"-"+str(self.data)+">"
|
return "<P"+str(self.player)+"-"+str(self.data)+">"
|
||||||
|
|
||||||
class NaiveUniverse():
|
class Universe():
|
||||||
def __init__(self):
|
def newOpen(self, node):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def merge(self, branch):
|
def merge(self, node):
|
||||||
return branch
|
return node
|
||||||
|
|
||||||
class BranchUniverse():
|
def clearPQ(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def iter(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def activateEdge(self, head):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class QueueingUniverse(Universe):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.branches = {}
|
self.pq = []
|
||||||
|
|
||||||
def merge(self, branch):
|
def newOpen(self, node):
|
||||||
tensor = branch.node.state.getTensor()
|
heapq.headpush(self.pq, (node.priority, node))
|
||||||
match = self.branches.get(tensor)
|
|
||||||
if match:
|
|
||||||
return match
|
|
||||||
else:
|
|
||||||
self.branches[tensor] = branch
|
|
||||||
|
|
||||||
class Branch():
|
def clearPQ(self):
|
||||||
def __new__(self, universe, preState, action): # fancy!
|
self.pq = []
|
||||||
self.preState = preState
|
|
||||||
self.action = action
|
def iter(self):
|
||||||
postState = preState.mutate(action)
|
yield heapq.heappop(self.pq)
|
||||||
self.node = Node(postState, universe=universe,
|
|
||||||
parent=preState, lastAction=action)
|
def activateEdge(self, head):
|
||||||
return universe.merge(self)
|
head._activateEdge()
|
||||||
|
|
||||||
|
|
||||||
class State(ABC):
|
class State(ABC):
|
||||||
@ -65,17 +69,16 @@ class State(ABC):
|
|||||||
# The calculated score should be 0 when won; higher when in a worse state; highest for loosing
|
# The calculated score should be 0 when won; higher when in a worse state; highest for loosing
|
||||||
# getPriority is used for prioritising certain Nodes / States when expanding / walking the tree
|
# getPriority is used for prioritising certain Nodes / States when expanding / walking the tree
|
||||||
|
|
||||||
def __init__(self, turn=0, generation=0, playersNum=2):
|
def __init__(self, curPlayer=0, generation=0, playersNum=2):
|
||||||
self.turn = turn
|
self.curPlayer = curPlayer
|
||||||
self.generation = generation
|
self.generation = generation
|
||||||
self.playersNum = playersNum
|
self.playersNum = playersNum
|
||||||
self.score = self.getScore()
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def mutate(self, action):
|
def mutate(self, action):
|
||||||
# Returns a new state with supplied action performed
|
# Returns a new state with supplied action performed
|
||||||
# self should not be changed
|
# self should not be changed
|
||||||
return State(turn=(self.turn+1) % self.playersNum, generation=self.generation+1, playersNum=self.playersNum)
|
return State(curPlayer=(self.curPlayer+1) % self.playersNum, generation=self.generation+1, playersNum=self.playersNum)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def getAvaibleActions(self):
|
def getAvaibleActions(self):
|
||||||
@ -87,8 +90,8 @@ class State(ABC):
|
|||||||
# Used for ordering the priority queue
|
# Used for ordering the priority queue
|
||||||
# Priority should not change for the same root
|
# Priority should not change for the same root
|
||||||
# Lower prioritys get worked on first
|
# Lower prioritys get worked on first
|
||||||
# Higher generations should have slightly higher priority
|
# Higher generations should have higher priority
|
||||||
return score + self.generation*0.1
|
return score + self.generation*0.5
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def checkWin(self):
|
def checkWin(self):
|
||||||
@ -98,12 +101,12 @@ class State(ABC):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# improveMe
|
# improveMe
|
||||||
def getScore(self):
|
def getScoreFor(self, player):
|
||||||
# 0 <= score <= 1; should return close to zero when we are winning
|
# 0 <= score <= 1; should return close to zero when we are winning
|
||||||
w = self.checkWin()
|
w = self.checkWin()
|
||||||
if w == None:
|
if w == None:
|
||||||
return 0.5
|
return 0.5
|
||||||
if w == 0:
|
if w == player:
|
||||||
return 0
|
return 0
|
||||||
if w == -1:
|
if w == -1:
|
||||||
return 0.9
|
return 0.9
|
||||||
@ -115,7 +118,7 @@ class State(ABC):
|
|||||||
return "[#]"
|
return "[#]"
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def getTensor(self):
|
def getTensor(self, phase='default'):
|
||||||
return torch.tensor([0])
|
return torch.tensor([0])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -123,166 +126,131 @@ class State(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def getScoreNeural(self):
|
def getScoreNeural(self):
|
||||||
pass
|
|
||||||
return self.model(self.getTensor())
|
return self.model(self.getTensor())
|
||||||
|
|
||||||
|
|
||||||
class Node():
|
class Node():
|
||||||
def __init__(self, state, universe=None, parent=None, lastAction=None, playersNum=2):
|
def __init__(self, state, universe=None, parent=None, lastAction=None):
|
||||||
self.state = state
|
self.state = state
|
||||||
if not universe:
|
if universe==None:
|
||||||
universe = NaiveUniverse()
|
universe = Universe()
|
||||||
# TODO: Maybe add self to new BranchUniverse?
|
|
||||||
self.universe = universe
|
self.universe = universe
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.lastAction = lastAction
|
self.lastAction = lastAction
|
||||||
self.playersNum = playersNum
|
|
||||||
|
|
||||||
self.childs = None
|
self._childs = None
|
||||||
self.score = state.getScore()
|
self._scores = [None]*self.state.playersNum
|
||||||
self.done = Event()
|
self._strongs = [None]*self.state.playersNum
|
||||||
self.threads = []
|
self._alive = True
|
||||||
self.walking = False
|
|
||||||
self.alive = True
|
|
||||||
|
|
||||||
def expand(self, shuffle=True):
|
def kill(self):
|
||||||
|
self._alive = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def childs(self):
|
||||||
|
if self._childs == None:
|
||||||
|
self._expand()
|
||||||
|
return self._childs
|
||||||
|
|
||||||
|
def _expand(self):
|
||||||
|
self._childs = []
|
||||||
actions = self.state.getAvaibleActions()
|
actions = self.state.getAvaibleActions()
|
||||||
if self.childs != None:
|
|
||||||
return True
|
|
||||||
self.childs = []
|
|
||||||
for action in actions:
|
for action in actions:
|
||||||
self.childs.append(Branch(self.universe, self.state, action))
|
newNode = Node(self.state.mutate(action), self.universe, self, action)
|
||||||
if self.childs == []:
|
self._childs.append(self.universe.merge(newNode))
|
||||||
return False
|
|
||||||
if shuffle:
|
@property
|
||||||
random.shuffle(self.childs)
|
def strongs(self):
|
||||||
|
return self._strongs
|
||||||
|
|
||||||
|
def _pullStrong(self): # Currently Expecti-Max
|
||||||
|
strongs = [None]*self.playersNum
|
||||||
|
for p in range(self.playersNum):
|
||||||
|
cp = self.state.curPlayer
|
||||||
|
if cp == p: # P owns the turn; controlls outcome
|
||||||
|
best = 10000000
|
||||||
|
for c in self.childs:
|
||||||
|
if c._strongs[cp] < best:
|
||||||
|
best = c._strongs[p]
|
||||||
|
strongs[p] = best
|
||||||
|
else:
|
||||||
|
scos = [(c._strongs[cp], c._strongs[p]) for c in self.childs]
|
||||||
|
scos.sort(key=lambda x: x[0])
|
||||||
|
betterHalf = scos[:max(3,int(len(scos)/2))]
|
||||||
|
myScores = [bh[1] for bh in betterHalf]
|
||||||
|
strongs[p] = sum(myScores)/len(myScores)
|
||||||
|
update = False
|
||||||
|
for s in range(self.playersNum):
|
||||||
|
if strongs[s] != self._strongs[s]:
|
||||||
|
update = True
|
||||||
|
break
|
||||||
|
self._strongs = strongs
|
||||||
|
if update:
|
||||||
|
self.parent._pullStrong()
|
||||||
|
|
||||||
|
def forceStrong(self, depth=3):
|
||||||
|
if depth==0:
|
||||||
|
self.strongDecay()
|
||||||
|
else:
|
||||||
|
for c in self.childs:
|
||||||
|
c.forceStrong(depth-1)
|
||||||
|
|
||||||
|
def strongDecay(self):
|
||||||
|
if self._strongs == [None]*self.playersNum:
|
||||||
|
if not self.scoresAvaible():
|
||||||
|
self._calcScores()
|
||||||
|
self._strongs = self._scores
|
||||||
|
self.parent._pullStrong()
|
||||||
|
|
||||||
|
def getSelfScore(self):
|
||||||
|
return self.getScoreFor(self.curPlayer)
|
||||||
|
|
||||||
|
def getScoreFor(self, player):
|
||||||
|
if self._scores[player] == None:
|
||||||
|
self._calcScore(player)
|
||||||
|
return self._scores[player]
|
||||||
|
|
||||||
|
def scoreAvaible(self, player):
|
||||||
|
return self._scores[player] != None
|
||||||
|
|
||||||
|
def scoresAvaible(self):
|
||||||
|
for p in self._scores:
|
||||||
|
if p==None:
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _perform(self, action):
|
def _calcScores(self):
|
||||||
if self.childs == None:
|
for p in range(self.state.playersNum):
|
||||||
raise PerformOnUnexpandedNodeException()
|
self._calcScore(p)
|
||||||
elif self.childs == []:
|
|
||||||
raise PerformOnTerminalNodeException()
|
|
||||||
for child in self.childs:
|
|
||||||
if child.node.lastAction == action:
|
|
||||||
self.endWalk()
|
|
||||||
return child
|
|
||||||
raise IllegalActionException()
|
|
||||||
|
|
||||||
def performBot(self):
|
def _calcScore(self, player):
|
||||||
if self.state.turn != 0:
|
self._scores[player] = self.state.getScoreFor(player)
|
||||||
raise NotBotsTurnException()
|
|
||||||
if self.childs == None:
|
|
||||||
raise PerformOnUnexpandedNodeException()
|
|
||||||
if self.childs == []:
|
|
||||||
raise PerformOnTerminalNodeException()
|
|
||||||
if self.walking:
|
|
||||||
self.endWalk()
|
|
||||||
bChild = self.childs[0]
|
|
||||||
for child in self.childs[1:]:
|
|
||||||
if not child:
|
|
||||||
print(self)
|
|
||||||
if child.node.score <= bChild.node.score:
|
|
||||||
bChild = child
|
|
||||||
return bChild
|
|
||||||
|
|
||||||
def performPlayer(self, action):
|
@property
|
||||||
if self.state.turn == 0:
|
def priority(self):
|
||||||
raise NotPlayersTurnException()
|
return self.state.getPriority(self.score)
|
||||||
return self._perform(action)
|
|
||||||
|
|
||||||
def getAvaibleActions(self):
|
@property
|
||||||
return self.state.getAvaibleActions()
|
def playersNum(self):
|
||||||
|
return self.state.playersNum
|
||||||
|
|
||||||
def getLastAction(self):
|
@property
|
||||||
return self.lastAction
|
def avaibleActions(self):
|
||||||
|
r = []
|
||||||
|
for c in self.childs:
|
||||||
|
r.append(c.lastAction)
|
||||||
|
return r
|
||||||
|
|
||||||
def beginWalk(self, threadNum=1):
|
@property
|
||||||
if self.walking:
|
def curPlayer(self):
|
||||||
raise Exception("Already Walking")
|
return self.state.curPlayer
|
||||||
self.walking = True
|
|
||||||
self.queue = PriorityQueue()
|
|
||||||
self.done.clear()
|
|
||||||
self.expand()
|
|
||||||
self._activateEdge()
|
|
||||||
for i in range(threadNum):
|
|
||||||
t = threading.Thread(target=self._worker)
|
|
||||||
t.start()
|
|
||||||
self.threads.append(t)
|
|
||||||
|
|
||||||
def endWalk(self):
|
def _activateEdge(self):
|
||||||
if not self.walking:
|
if not self.strongScoresAvaible():
|
||||||
raise Exception("Not Walking")
|
self.universe.newOpen(self)
|
||||||
self.done.set()
|
else:
|
||||||
for t in self.threads:
|
for c in self.childs:
|
||||||
t.join()
|
c._activateEdge()
|
||||||
self.walking = False
|
|
||||||
|
|
||||||
def walkUntilDone(self):
|
|
||||||
if not self.walking:
|
|
||||||
self.beginWalk()
|
|
||||||
for t in self.threads:
|
|
||||||
t.join()
|
|
||||||
self.done.set()
|
|
||||||
|
|
||||||
def syncWalk(self, time, threads=16):
|
|
||||||
self.beginWalk(threadNum=threadNum)
|
|
||||||
time.sleep(time)
|
|
||||||
self.endWalk()
|
|
||||||
|
|
||||||
def _worker(self):
|
|
||||||
while not self.done.is_set():
|
|
||||||
try:
|
|
||||||
node = self.queue.get_nowait()
|
|
||||||
except Empty:
|
|
||||||
continue
|
|
||||||
if node.alive:
|
|
||||||
if node.expand():
|
|
||||||
node._updateScore()
|
|
||||||
if self.done.is_set():
|
|
||||||
queque.task_done()
|
|
||||||
break
|
|
||||||
if node.state.checkWin == None:
|
|
||||||
for c in node.childs:
|
|
||||||
self.queue.put(c.node)
|
|
||||||
self.queue.task_done()
|
|
||||||
|
|
||||||
def _activateEdge(self, node=None):
|
|
||||||
if node == None:
|
|
||||||
node = self
|
|
||||||
if node.childs == None:
|
|
||||||
self.queue.put(node)
|
|
||||||
elif node.alive:
|
|
||||||
for c in node.childs:
|
|
||||||
self._activateEdge(node=c.node)
|
|
||||||
|
|
||||||
def __lt__(self, other):
|
|
||||||
# Used for ordering the priority queue
|
|
||||||
return self.state.getPriority(self.score) < other.state.getPriority(self.score)
|
|
||||||
|
|
||||||
# improveMe
|
|
||||||
def _calcAggScore(self):
|
|
||||||
if self.childs != None and self.childs != []:
|
|
||||||
scores = [c.node.score for c in self.childs]
|
|
||||||
if self.state.turn == 0:
|
|
||||||
self.score = min(scores)
|
|
||||||
elif self.playersNum == 2:
|
|
||||||
self.score = max(scores)
|
|
||||||
else:
|
|
||||||
# Note: This might be tweaked
|
|
||||||
self.score = (max(scores) + sum(scores)/len(scores)) / 2
|
|
||||||
|
|
||||||
def _updateScore(self):
|
|
||||||
oldScore = self.score
|
|
||||||
self._calcAggScore()
|
|
||||||
if self.score != oldScore:
|
|
||||||
self._pushScore()
|
|
||||||
|
|
||||||
def _pushScore(self):
|
|
||||||
if self.parent != None:
|
|
||||||
self.parent._updateScore()
|
|
||||||
elif self.score == 0:
|
|
||||||
self.done.set()
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
s = []
|
s = []
|
||||||
@ -290,143 +258,71 @@ class Node():
|
|||||||
s.append("[ {ROOT} ]")
|
s.append("[ {ROOT} ]")
|
||||||
else:
|
else:
|
||||||
s.append("[ -> "+str(self.lastAction)+" ]")
|
s.append("[ -> "+str(self.lastAction)+" ]")
|
||||||
s.append("[ turn: "+str(self.state.turn)+" ]")
|
s.append("[ turn: "+str(self.state.curPlayer)+" ]")
|
||||||
s.append(str(self.state))
|
s.append(str(self.state))
|
||||||
s.append("[ score: "+str(self.score)+" ]")
|
s.append("[ score: "+str(self.getSelfScore())+" ]")
|
||||||
return '\n'.join(s)
|
return '\n'.join(s)
|
||||||
|
|
||||||
|
def choose(txt, options):
|
||||||
|
while True:
|
||||||
|
print('[*] '+txt)
|
||||||
|
for num,opt in enumerate(options):
|
||||||
|
print('['+str(num+1)+'] ' + str(opt))
|
||||||
|
inp = input('[> ')
|
||||||
|
try:
|
||||||
|
n = int(inp)
|
||||||
|
if n in range(1,len(options)+1):
|
||||||
|
return options[n-1]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
for opt in options:
|
||||||
|
if inp==str(opt):
|
||||||
|
return opt
|
||||||
|
if len(inp)==1:
|
||||||
|
for opt in options:
|
||||||
|
if inp==str(opt)[0]:
|
||||||
|
return opt
|
||||||
|
print('[!] Invalid Input.')
|
||||||
|
|
||||||
class WeakSolver():
|
class Runtime():
|
||||||
def __init__(self, state):
|
def __init__(self, initState):
|
||||||
self.node = Node(state)
|
self.head = Node(initState)
|
||||||
|
|
||||||
def play(self):
|
def performAction(self, action):
|
||||||
while self.node.state.checkWin() == None:
|
for c in self.head.childs:
|
||||||
self.step()
|
if action == c.lastAction:
|
||||||
print(self.node)
|
self.head.universe.clearPQ()
|
||||||
print("[*] " + str(self.node.state.checkWin()) + " won!")
|
self.head.kill()
|
||||||
if self.node.walking:
|
self.head = c
|
||||||
self.node.endWalk()
|
self.head.universe.activateEdge(self.head)
|
||||||
|
return
|
||||||
|
raise Exception('No such action avaible...')
|
||||||
|
|
||||||
def step(self):
|
def turn(self, bot=None):
|
||||||
if self.node.state.turn == 0:
|
print(str(self.head))
|
||||||
self.botStep()
|
if bot==None:
|
||||||
|
c = choose('?', ['human', 'bot', 'undo'])
|
||||||
|
if c=='undo':
|
||||||
|
self.head = self.head.parent
|
||||||
|
return
|
||||||
|
bot = c=='bot'
|
||||||
|
if bot:
|
||||||
|
opts = []
|
||||||
|
for c in self.head.childs:
|
||||||
|
opts.append((c, c.getStrongScore(self.head.curPlayer, -1)[0]))
|
||||||
|
opts.sort(key=lambda x: x[1])
|
||||||
|
print('[i] Evaluated Options:')
|
||||||
|
for o in opts:
|
||||||
|
#print('['+str(o[0])+']' + str(o[0].lastAction) + " (Score: "+str(o[1])+")")
|
||||||
|
print('[ ]' + str(o[0].lastAction) + " (Score: "+str(o[1])+")")
|
||||||
|
print('[#] I choose to play: ' + str(opts[0][0].lastAction))
|
||||||
|
self.performAction(opts[0][0].lastAction)
|
||||||
else:
|
else:
|
||||||
self.playerStep()
|
action = choose('What does player '+str(self.head.curPlayer)+' want to do?', self.head.avaibleActions)
|
||||||
|
self.performAction(action)
|
||||||
|
|
||||||
def botStep(self):
|
def game(self, bots=None):
|
||||||
if self.node.walking:
|
if bots==None:
|
||||||
self.node.endWalk()
|
bots = [None]*self.head.playersNum
|
||||||
self.node.expand()
|
|
||||||
self.node = self.node.performBot().node
|
|
||||||
print("[*] Bot did "+str(self.node.lastAction))
|
|
||||||
|
|
||||||
def playerStep(self):
|
|
||||||
self.node.beginWalk()
|
|
||||||
print(self.node)
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
self.turn(bots[self.head.curPlayer])
|
||||||
newNode = self.node.performPlayer(
|
|
||||||
Action(self.node.state.turn, int(input("[#]> "))))
|
|
||||||
except IllegalActionException:
|
|
||||||
print("[!] Illegal Action")
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
self.node.endWalk()
|
|
||||||
self.node = newNode
|
|
||||||
|
|
||||||
|
|
||||||
class NeuralTrainer():
|
|
||||||
def __init__(self, StateClass):
|
|
||||||
self.State = StateClass
|
|
||||||
self.model = self.State.buildModel()
|
|
||||||
|
|
||||||
def train(self, states, scores, rounds=2000):
|
|
||||||
loss_fn = torch.nn.MSELoss(reduction='sum')
|
|
||||||
learning_rate = 1e-6
|
|
||||||
for t in range(rounds):
|
|
||||||
y_pred = self.model(states[t % len(states)])
|
|
||||||
y = scores[t % len(states)]
|
|
||||||
loss = loss_fn(y_pred, y)
|
|
||||||
print(t, loss.item())
|
|
||||||
self.model.zeroGrad()
|
|
||||||
loss.backwards()
|
|
||||||
with torch.no_grad():
|
|
||||||
for param in model.parameters():
|
|
||||||
param -= learning_rate * param.grad
|
|
||||||
|
|
||||||
def setWeights(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def getWeights(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def loadWeights(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def storeWeights(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SelfPlayDataGen():
|
|
||||||
def __init__(self, StateClass, playersNum, compTime=30):
|
|
||||||
self.State = StateClass
|
|
||||||
self.playersNum = playersNum
|
|
||||||
self.compTime = compTime
|
|
||||||
self.gameStates = []
|
|
||||||
|
|
||||||
def game(self):
|
|
||||||
self.nodes = []
|
|
||||||
for p in range(playersNum):
|
|
||||||
self.nodes.append(Node(self.State(
|
|
||||||
turn=(-p) % self.playersNum, generation=0, playersNum=self.playersNum)))
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if (winner := self.nodes[0].state.checkWin) != None:
|
|
||||||
return winner
|
|
||||||
for n in self.nodes:
|
|
||||||
n.beginWalk()
|
|
||||||
time.sleep(compTime)
|
|
||||||
for n in self.nodes:
|
|
||||||
n.endWalk()
|
|
||||||
self.step()
|
|
||||||
self.gameStates.append(
|
|
||||||
[self.nodes[0].state.getTensor(), self.nodes[0].score])
|
|
||||||
|
|
||||||
def step(self):
|
|
||||||
turn = self.nodes[0].state.turn
|
|
||||||
self.nodes[turn] = self.nodes[turn].performBot()
|
|
||||||
action = self.nodes[turn].lastAction
|
|
||||||
for n in range(self.playersNum):
|
|
||||||
if n != turn:
|
|
||||||
action.player = 0
|
|
||||||
self.nodes[n] = self.nodes[n].performPlayer(action)
|
|
||||||
return self.nodes[0].state.checkWin()
|
|
||||||
|
|
||||||
|
|
||||||
class VacuumDecayException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class IllegalActionException(VacuumDecayException):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class PerformOnUnexpandedNodeException(VacuumDecayException):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class PerformOnTerminalNodeException(VacuumDecayException):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class IllegalTurnException(VacuumDecayException):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NotBotsTurnException(IllegalTurnException):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NotPlayersTurnException(IllegalTurnException):
|
|
||||||
pass
|
|
||||||
|
Loading…
Reference in New Issue
Block a user