This commit is contained in:
Dominik Moritz Roth 2022-04-13 22:49:38 +02:00
parent a46557a635
commit 5ba277a2aa
2 changed files with 208 additions and 313 deletions

View File

@ -2,24 +2,23 @@ from vacuumDecay import *
import numpy as np import numpy as np
class TTTState(State): class TTTState(State):
def __init__(self, turn=0, generation=0, playersNum=2, board=None): def __init__(self, curPlayer=0, generation=0, playersNum=2, board=None):
if type(board) == type(None): if type(board) == type(None):
board = np.array([None]*9) board = np.array([None]*9)
self.turn = turn self.curPlayer = curPlayer
self.generation = generation self.generation = generation
self.playersNum = playersNum self.playersNum = playersNum
self.board = board self.board = board
self.score = self.getScore()
def mutate(self, action): def mutate(self, action):
newBoard = np.copy(self.board) newBoard = np.copy(self.board)
newBoard[action.data] = self.turn newBoard[action.data] = self.curPlayer
return TTTState(turn=(self.turn+1)%self.playersNum, playersNum=self.playersNum, board=newBoard) return TTTState(curPlayer=(self.curPlayer+1)%self.playersNum, playersNum=self.playersNum, board=newBoard)
def getAvaibleActions(self): def getAvaibleActions(self):
for i in range(9): for i in range(9):
if self.board[i]==None: if self.board[i]==None:
yield Action(self.turn, i) yield Action(self.curPlayer, i)
def checkWin(self): def checkWin(self):
s = self.board s = self.board
@ -49,13 +48,13 @@ class TTTState(State):
@classmethod @classmethod
def getModel(): def getModel():
return torch.nn.Sequential( return torch.nn.Sequential(
torch.nn.Linear(10, 10) torch.nn.Linear(10, 10),
torch.nn.ReLu() torch.nn.ReLu(),
torch.nn.Linear(10, 3) torch.nn.Linear(10, 3),
torch.nn.Sigmoid() torch.nn.Sigmoid(),
torch.nn.Linear(3,1) torch.nn.Linear(3,1)
) )
if __name__=="__main__": if __name__=="__main__":
vd = VacuumDecay(TTTState()) run = Runtime(TTTState())
vd.weakPlay() run.game()

View File

@ -28,33 +28,37 @@ class Action():
# should start with < and end with > # should start with < and end with >
return "<P"+str(self.player)+"-"+str(self.data)+">" return "<P"+str(self.player)+"-"+str(self.data)+">"
class NaiveUniverse(): class Universe():
def __init__(self): def newOpen(self, node):
pass pass
def merge(self, branch): def merge(self, node):
return branch return node
class BranchUniverse(): def clearPQ(self):
pass
def iter(self):
return []
def activateEdge(self, head):
pass
class QueueingUniverse(Universe):
def __init__(self): def __init__(self):
self.branches = {} self.pq = []
def merge(self, branch): def newOpen(self, node):
tensor = branch.node.state.getTensor() heapq.headpush(self.pq, (node.priority, node))
match = self.branches.get(tensor)
if match:
return match
else:
self.branches[tensor] = branch
class Branch(): def clearPQ(self):
def __new__(self, universe, preState, action): # fancy! self.pq = []
self.preState = preState
self.action = action def iter(self):
postState = preState.mutate(action) yield heapq.heappop(self.pq)
self.node = Node(postState, universe=universe,
parent=preState, lastAction=action) def activateEdge(self, head):
return universe.merge(self) head._activateEdge()
class State(ABC): class State(ABC):
@ -65,17 +69,16 @@ class State(ABC):
# The calculated score should be 0 when won; higher when in a worse state; highest for loosing # The calculated score should be 0 when won; higher when in a worse state; highest for loosing
# getPriority is used for prioritising certain Nodes / States when expanding / walking the tree # getPriority is used for prioritising certain Nodes / States when expanding / walking the tree
def __init__(self, turn=0, generation=0, playersNum=2): def __init__(self, curPlayer=0, generation=0, playersNum=2):
self.turn = turn self.curPlayer = curPlayer
self.generation = generation self.generation = generation
self.playersNum = playersNum self.playersNum = playersNum
self.score = self.getScore()
@abstractmethod @abstractmethod
def mutate(self, action): def mutate(self, action):
# Returns a new state with supplied action performed # Returns a new state with supplied action performed
# self should not be changed # self should not be changed
return State(turn=(self.turn+1) % self.playersNum, generation=self.generation+1, playersNum=self.playersNum) return State(curPlayer=(self.curPlayer+1) % self.playersNum, generation=self.generation+1, playersNum=self.playersNum)
@abstractmethod @abstractmethod
def getAvaibleActions(self): def getAvaibleActions(self):
@ -87,8 +90,8 @@ class State(ABC):
# Used for ordering the priority queue # Used for ordering the priority queue
# Priority should not change for the same root # Priority should not change for the same root
# Lower prioritys get worked on first # Lower prioritys get worked on first
# Higher generations should have slightly higher priority # Higher generations should have higher priority
return score + self.generation*0.1 return score + self.generation*0.5
@abstractmethod @abstractmethod
def checkWin(self): def checkWin(self):
@ -98,12 +101,12 @@ class State(ABC):
return None return None
# improveMe # improveMe
def getScore(self): def getScoreFor(self, player):
# 0 <= score <= 1; should return close to zero when we are winning # 0 <= score <= 1; should return close to zero when we are winning
w = self.checkWin() w = self.checkWin()
if w == None: if w == None:
return 0.5 return 0.5
if w == 0: if w == player:
return 0 return 0
if w == -1: if w == -1:
return 0.9 return 0.9
@ -115,7 +118,7 @@ class State(ABC):
return "[#]" return "[#]"
@abstractmethod @abstractmethod
def getTensor(self): def getTensor(self, phase='default'):
return torch.tensor([0]) return torch.tensor([0])
@classmethod @classmethod
@ -123,166 +126,131 @@ class State(ABC):
pass pass
def getScoreNeural(self): def getScoreNeural(self):
pass
return self.model(self.getTensor()) return self.model(self.getTensor())
class Node(): class Node():
def __init__(self, state, universe=None, parent=None, lastAction=None, playersNum=2): def __init__(self, state, universe=None, parent=None, lastAction=None):
self.state = state self.state = state
if not universe: if universe==None:
universe = NaiveUniverse() universe = Universe()
# TODO: Maybe add self to new BranchUniverse?
self.universe = universe self.universe = universe
self.parent = parent self.parent = parent
self.lastAction = lastAction self.lastAction = lastAction
self.playersNum = playersNum
self.childs = None self._childs = None
self.score = state.getScore() self._scores = [None]*self.state.playersNum
self.done = Event() self._strongs = [None]*self.state.playersNum
self.threads = [] self._alive = True
self.walking = False
self.alive = True
def expand(self, shuffle=True): def kill(self):
self._alive = False
@property
def childs(self):
if self._childs == None:
self._expand()
return self._childs
def _expand(self):
self._childs = []
actions = self.state.getAvaibleActions() actions = self.state.getAvaibleActions()
if self.childs != None:
return True
self.childs = []
for action in actions: for action in actions:
self.childs.append(Branch(self.universe, self.state, action)) newNode = Node(self.state.mutate(action), self.universe, self, action)
if self.childs == []: self._childs.append(self.universe.merge(newNode))
return False
if shuffle: @property
random.shuffle(self.childs) def strongs(self):
return self._strongs
def _pullStrong(self): # Currently Expecti-Max
strongs = [None]*self.playersNum
for p in range(self.playersNum):
cp = self.state.curPlayer
if cp == p: # P owns the turn; controlls outcome
best = 10000000
for c in self.childs:
if c._strongs[cp] < best:
best = c._strongs[p]
strongs[p] = best
else:
scos = [(c._strongs[cp], c._strongs[p]) for c in self.childs]
scos.sort(key=lambda x: x[0])
betterHalf = scos[:max(3,int(len(scos)/2))]
myScores = [bh[1] for bh in betterHalf]
strongs[p] = sum(myScores)/len(myScores)
update = False
for s in range(self.playersNum):
if strongs[s] != self._strongs[s]:
update = True
break
self._strongs = strongs
if update:
self.parent._pullStrong()
def forceStrong(self, depth=3):
if depth==0:
self.strongDecay()
else:
for c in self.childs:
c.forceStrong(depth-1)
def strongDecay(self):
if self._strongs == [None]*self.playersNum:
if not self.scoresAvaible():
self._calcScores()
self._strongs = self._scores
self.parent._pullStrong()
def getSelfScore(self):
return self.getScoreFor(self.curPlayer)
def getScoreFor(self, player):
if self._scores[player] == None:
self._calcScore(player)
return self._scores[player]
def scoreAvaible(self, player):
return self._scores[player] != None
def scoresAvaible(self):
for p in self._scores:
if p==None:
return False
return True return True
def _perform(self, action): def _calcScores(self):
if self.childs == None: for p in range(self.state.playersNum):
raise PerformOnUnexpandedNodeException() self._calcScore(p)
elif self.childs == []:
raise PerformOnTerminalNodeException()
for child in self.childs:
if child.node.lastAction == action:
self.endWalk()
return child
raise IllegalActionException()
def performBot(self): def _calcScore(self, player):
if self.state.turn != 0: self._scores[player] = self.state.getScoreFor(player)
raise NotBotsTurnException()
if self.childs == None:
raise PerformOnUnexpandedNodeException()
if self.childs == []:
raise PerformOnTerminalNodeException()
if self.walking:
self.endWalk()
bChild = self.childs[0]
for child in self.childs[1:]:
if not child:
print(self)
if child.node.score <= bChild.node.score:
bChild = child
return bChild
def performPlayer(self, action): @property
if self.state.turn == 0: def priority(self):
raise NotPlayersTurnException() return self.state.getPriority(self.score)
return self._perform(action)
def getAvaibleActions(self): @property
return self.state.getAvaibleActions() def playersNum(self):
return self.state.playersNum
def getLastAction(self): @property
return self.lastAction def avaibleActions(self):
r = []
for c in self.childs:
r.append(c.lastAction)
return r
def beginWalk(self, threadNum=1): @property
if self.walking: def curPlayer(self):
raise Exception("Already Walking") return self.state.curPlayer
self.walking = True
self.queue = PriorityQueue()
self.done.clear()
self.expand()
self._activateEdge()
for i in range(threadNum):
t = threading.Thread(target=self._worker)
t.start()
self.threads.append(t)
def endWalk(self): def _activateEdge(self):
if not self.walking: if not self.strongScoresAvaible():
raise Exception("Not Walking") self.universe.newOpen(self)
self.done.set() else:
for t in self.threads: for c in self.childs:
t.join() c._activateEdge()
self.walking = False
def walkUntilDone(self):
if not self.walking:
self.beginWalk()
for t in self.threads:
t.join()
self.done.set()
def syncWalk(self, time, threads=16):
self.beginWalk(threadNum=threadNum)
time.sleep(time)
self.endWalk()
def _worker(self):
while not self.done.is_set():
try:
node = self.queue.get_nowait()
except Empty:
continue
if node.alive:
if node.expand():
node._updateScore()
if self.done.is_set():
queque.task_done()
break
if node.state.checkWin == None:
for c in node.childs:
self.queue.put(c.node)
self.queue.task_done()
def _activateEdge(self, node=None):
if node == None:
node = self
if node.childs == None:
self.queue.put(node)
elif node.alive:
for c in node.childs:
self._activateEdge(node=c.node)
def __lt__(self, other):
# Used for ordering the priority queue
return self.state.getPriority(self.score) < other.state.getPriority(self.score)
# improveMe
def _calcAggScore(self):
if self.childs != None and self.childs != []:
scores = [c.node.score for c in self.childs]
if self.state.turn == 0:
self.score = min(scores)
elif self.playersNum == 2:
self.score = max(scores)
else:
# Note: This might be tweaked
self.score = (max(scores) + sum(scores)/len(scores)) / 2
def _updateScore(self):
oldScore = self.score
self._calcAggScore()
if self.score != oldScore:
self._pushScore()
def _pushScore(self):
if self.parent != None:
self.parent._updateScore()
elif self.score == 0:
self.done.set()
def __str__(self): def __str__(self):
s = [] s = []
@ -290,143 +258,71 @@ class Node():
s.append("[ {ROOT} ]") s.append("[ {ROOT} ]")
else: else:
s.append("[ -> "+str(self.lastAction)+" ]") s.append("[ -> "+str(self.lastAction)+" ]")
s.append("[ turn: "+str(self.state.turn)+" ]") s.append("[ turn: "+str(self.state.curPlayer)+" ]")
s.append(str(self.state)) s.append(str(self.state))
s.append("[ score: "+str(self.score)+" ]") s.append("[ score: "+str(self.getSelfScore())+" ]")
return '\n'.join(s) return '\n'.join(s)
def choose(txt, options):
while True:
print('[*] '+txt)
for num,opt in enumerate(options):
print('['+str(num+1)+'] ' + str(opt))
inp = input('[> ')
try:
n = int(inp)
if n in range(1,len(options)+1):
return options[n-1]
except:
pass
for opt in options:
if inp==str(opt):
return opt
if len(inp)==1:
for opt in options:
if inp==str(opt)[0]:
return opt
print('[!] Invalid Input.')
class WeakSolver(): class Runtime():
def __init__(self, state): def __init__(self, initState):
self.node = Node(state) self.head = Node(initState)
def play(self): def performAction(self, action):
while self.node.state.checkWin() == None: for c in self.head.childs:
self.step() if action == c.lastAction:
print(self.node) self.head.universe.clearPQ()
print("[*] " + str(self.node.state.checkWin()) + " won!") self.head.kill()
if self.node.walking: self.head = c
self.node.endWalk() self.head.universe.activateEdge(self.head)
return
raise Exception('No such action avaible...')
def step(self): def turn(self, bot=None):
if self.node.state.turn == 0: print(str(self.head))
self.botStep() if bot==None:
c = choose('?', ['human', 'bot', 'undo'])
if c=='undo':
self.head = self.head.parent
return
bot = c=='bot'
if bot:
opts = []
for c in self.head.childs:
opts.append((c, c.getStrongScore(self.head.curPlayer, -1)[0]))
opts.sort(key=lambda x: x[1])
print('[i] Evaluated Options:')
for o in opts:
#print('['+str(o[0])+']' + str(o[0].lastAction) + " (Score: "+str(o[1])+")")
print('[ ]' + str(o[0].lastAction) + " (Score: "+str(o[1])+")")
print('[#] I choose to play: ' + str(opts[0][0].lastAction))
self.performAction(opts[0][0].lastAction)
else: else:
self.playerStep() action = choose('What does player '+str(self.head.curPlayer)+' want to do?', self.head.avaibleActions)
self.performAction(action)
def botStep(self): def game(self, bots=None):
if self.node.walking: if bots==None:
self.node.endWalk() bots = [None]*self.head.playersNum
self.node.expand()
self.node = self.node.performBot().node
print("[*] Bot did "+str(self.node.lastAction))
def playerStep(self):
self.node.beginWalk()
print(self.node)
while True: while True:
try: self.turn(bots[self.head.curPlayer])
newNode = self.node.performPlayer(
Action(self.node.state.turn, int(input("[#]> "))))
except IllegalActionException:
print("[!] Illegal Action")
else:
break
self.node.endWalk()
self.node = newNode
class NeuralTrainer():
def __init__(self, StateClass):
self.State = StateClass
self.model = self.State.buildModel()
def train(self, states, scores, rounds=2000):
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-6
for t in range(rounds):
y_pred = self.model(states[t % len(states)])
y = scores[t % len(states)]
loss = loss_fn(y_pred, y)
print(t, loss.item())
self.model.zeroGrad()
loss.backwards()
with torch.no_grad():
for param in model.parameters():
param -= learning_rate * param.grad
def setWeights(self):
pass
def getWeights(self):
pass
def loadWeights(self):
pass
def storeWeights(self):
pass
class SelfPlayDataGen():
def __init__(self, StateClass, playersNum, compTime=30):
self.State = StateClass
self.playersNum = playersNum
self.compTime = compTime
self.gameStates = []
def game(self):
self.nodes = []
for p in range(playersNum):
self.nodes.append(Node(self.State(
turn=(-p) % self.playersNum, generation=0, playersNum=self.playersNum)))
while True:
if (winner := self.nodes[0].state.checkWin) != None:
return winner
for n in self.nodes:
n.beginWalk()
time.sleep(compTime)
for n in self.nodes:
n.endWalk()
self.step()
self.gameStates.append(
[self.nodes[0].state.getTensor(), self.nodes[0].score])
def step(self):
turn = self.nodes[0].state.turn
self.nodes[turn] = self.nodes[turn].performBot()
action = self.nodes[turn].lastAction
for n in range(self.playersNum):
if n != turn:
action.player = 0
self.nodes[n] = self.nodes[n].performPlayer(action)
return self.nodes[0].state.checkWin()
class VacuumDecayException(Exception):
pass
class IllegalActionException(VacuumDecayException):
pass
class PerformOnUnexpandedNodeException(VacuumDecayException):
pass
class PerformOnTerminalNodeException(VacuumDecayException):
pass
class IllegalTurnException(VacuumDecayException):
pass
class NotBotsTurnException(IllegalTurnException):
pass
class NotPlayersTurnException(IllegalTurnException):
pass