Initial commit
This commit is contained in:
commit
a46557a635
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
__pycache__
|
||||||
|
*.~*
|
21
README.md
Normal file
21
README.md
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# Project vacuumDecay
|
||||||
|
|
||||||
|
Project vacuumDecay is a framework for building AIs for games.
|
||||||
|
Available architectures are:
|
||||||
|
- those used in Deep Blue (mini-max / expecti-max)
|
||||||
|
- advanced expecti-max exploration based on utility heuristics
|
||||||
|
- those used in AlphaGo Zero (knowledge distillation using neural networks)
|
||||||
|
|
||||||
|
A new AI is created by subclassing the State-class and defining the following functionality (mycelia.py provides a template):
|
||||||
|
- initialization (generating the gameboard or similar)
|
||||||
|
- getting available actions for the current situation (returns an Action-object, which can be subclassed to add additional functionality)
|
||||||
|
- applying an action (the state itself should be immutable, a new state should be returned)
|
||||||
|
- checking for a winning-condition (should return None if game has not yet ended)
|
||||||
|
- (optional) a getter for a string-representation of the current state
|
||||||
|
- (optional) a heuristic for the winning-condition (greatly improves capability)
|
||||||
|
- (optional) a getter for a tensor that describes the current game state (required for knowledge distillation)
|
||||||
|
- (optional) interface to allow a human to select an action
|
||||||
|
|
||||||
|
### Current state of the project
|
||||||
|
It currently does not work and implements none of the named functionality in a working fashion.
|
||||||
|
Experiments for TicTacToe, Dikehiker and an encryption-breaker for RC4 are being worked on.
|
61
dikehiker.py
Normal file
61
dikehiker.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
from vacuumDecay import *
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class TTTState(State):
    """Game state of the Dikehiker betting game.

    Each player owns a bank and places one bet per round. When the last
    player of a round has bet, the lowest bet is eliminated and the highest
    bet is credited 500.

    NOTE(review): the class name was apparently copied from tictactoe.py; it
    is kept so existing references keep working.
    """

    def __init__(self, turn=0, generation=0, playersNum=4, bank=None, bet=None, alive=None):
        """Create a state.

        turn       -- index of the player to move
        generation -- number of mutations since the initial state
        playersNum -- number of players
        bank       -- per-player funds (default: [2904, 3135, 2563, 0])
        bet        -- per-player bet of the current round (default: four zeros)
        alive      -- per-player truthy/falsy "still in the game" flags
                      (default: everybody alive)
        """
        # Defaults are created per instance; list literals in the signature
        # would be shared between every state built with the defaults.
        if bank is None:
            bank = [2904, 3135, 2563, 0]
        if bet is None:
            bet = [0] * 4
        if alive is None:
            alive = [1] * playersNum
        self.turn = turn
        self.generation = generation
        self.playersNum = playersNum
        self.bank = bank
        self.bet = bet
        self.alive = alive
        self.score = self.getScore()

    def mutate(self, action):
        """Return the successor state after the current player bets action.data.

        The current state is left untouched: bank, bet and alive are copied
        before being modified.
        """
        newBank = np.copy(self.bank)
        newBet = np.copy(self.bet)
        newAlive = list(self.alive)
        newBet[self.turn] = action.data
        # Only positive bets cost money; negative bets leave the bank as is.
        newBank[self.turn] = newBank[self.turn] - max(0, newBet[self.turn])
        if self.turn == self.playersNum - 1:
            # End of the round: settle the bets.
            # NOTE(review): eliminated players are still part of this
            # comparison (as in the original code) -- confirm the rules.
            players = range(len(newBet))
            loser = min(players, key=newBet.__getitem__)
            winner = max(players, key=newBet.__getitem__)
            newAlive[loser] = False
            newBank[winner] += 500
        return TTTState(
            turn=(self.turn + 1) % self.playersNum,
            generation=self.generation + 1,
            playersNum=self.playersNum,
            bank=newBank,
            bet=newBet,
            alive=newAlive,
        )

    def getAvaibleActions(self):
        """Yield every bet the current player may place.

        Living players may bet anything from -(playersNum + 1) up to their
        whole bank; eliminated players can only "bet" 0.
        """
        if self.alive[self.turn]:
            for b in range(-self.playersNum - 1, self.bank[self.turn] + 1):
                yield Action(self.turn, b)
        else:
            yield Action(self.turn, 0)

    def checkWin(self):
        """Return the index of the last living player, or None while the game runs."""
        if sum(self.alive) == 1:
            for player, isAlive in enumerate(self.alive):
                if isAlive:
                    return player
        return None

    def getScore(self):
        """Heuristic: richest bank plus all banks, minus twice the current player's bank."""
        return max(self.bank) + sum(self.bank) - self.bank[self.turn] * 2

    def __str__(self):
        """One line per player: '<bet> -> <bank>' or '<dead>'."""
        lines = []
        for player, funds in enumerate(self.bank):
            if self.alive[player]:
                lines.append(str(self.bet[player]) + ' -> ' + str(funds))
            else:
                lines.append('<dead>')
        return "\n".join(lines)

    def getTensor(self):
        """No tensor representation is implemented for this game."""
        return None

    @classmethod
    def getModel(cls):
        """No neural model is implemented for this game."""
        return None
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # WeakSolver has no selfPlay(); play() is its game loop.
    vd = WeakSolver(TTTState())
    vd.play()
|
72
encBreaker.py
Normal file
72
encBreaker.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
from vacuumDecay import *
|
||||||
|
from arc4 import ARC4
|
||||||
|
import copy
|
||||||
|
|
||||||
|
class KnownPlaintextAndKeylen(State, ABC):
    """Single-player search state for a known-plaintext attack with known key length.

    The state holds a candidate key as a list of bits. An action flips one
    bit. Subclasses supply the cipher by implementing _decrypt().
    """

    def __init__(self, plaintext, ciphertext, keyLenBits, keyBits=None, turn=0, generation=0, playersNum=1, lastChange=None):
        """Create a state.

        plaintext  -- the known plaintext (str)
        ciphertext -- the matching ciphertext, in whatever form _decrypt expects
        keyLenBits -- key length in bits
        keyBits    -- candidate key as a list of 0/1 (default: all zero)
        lastChange -- index of the bit flipped to reach this state, if any
        """
        if keyBits is None:
            keyBits = [0] * keyLenBits
        self.turn = turn
        self.generation = generation
        self.playersNum = playersNum
        self.keyBits = keyBits
        self.keyLenBits = keyLenBits
        self.plaintext = plaintext
        self.ciphertext = ciphertext
        self.lastChange = lastChange
        # The decryption is computed once; getScore and checkWin both read it.
        self.decrypt = self._decrypt()
        self.score = self.getScore()

    def mutate(self, action):
        """Return a new state of the same class with bit action.data flipped."""
        newKeyBits = copy.copy(self.keyBits)
        newKeyBits[action.data] = int(not newKeyBits[action.data])
        # type(self) keeps the concrete cipher; hard-coding one subclass here
        # would silently switch ciphers after the first move.
        return type(self)(self.plaintext, self.ciphertext, self.keyLenBits, newKeyBits,
                          generation=self.generation + 1, lastChange=action.data)

    def getAvaibleActions(self):
        """Yield one flip action per key bit, except the bit flipped last."""
        for i in range(self.keyLenBits):
            # Flipping the same bit again would just undo the previous move.
            if self.lastChange != i:
                yield Action(0, i)

    def getKey(self):
        """Return the candidate key as a str, one character per complete 8-bit group."""
        chars = []
        for i in range(self.keyLenBits // 8):
            byteBits = self.keyBits[i * 8:(i + 1) * 8]
            chars.append(chr(int("".join(str(bit) for bit in byteBits), 2)))
        return "".join(chars)

    @abstractmethod
    def _decrypt(self):
        """Return the ciphertext decrypted with the current candidate key (str)."""
        pass

    def checkWin(self):
        """Return 0 (the only player wins) once the decryption matches, else None.

        Callers test the result against None, so a plain False would be taken
        for a finished game.
        """
        if self.decrypt == self.plaintext:
            return 0
        return None

    def getScore(self):
        """Return the fraction of differing bits between decryption and plaintext.

        0 means a perfect match. Only the overlapping prefix of the two
        strings is compared.
        """
        totalBits = len(self.plaintext) * 8
        if totalBits == 0:
            # Nothing to compare; avoid dividing by zero.
            return 0
        diff = sum(bin(ord(a) ^ ord(b)).count("1") for a, b in zip(self.decrypt, self.plaintext))
        return diff / totalBits

    def __str__(self):
        """Render as '{key}[decryption]'."""
        return "{" + self.getKey() + "}[" + self.decrypt + "]"

    def getTensor(self):
        """Return the key bits followed by the bits of the decryption as one tensor."""
        decryptBits = ''.join(format(ord(c), '08b') for c in self.decrypt)
        return torch.tensor(self.keyBits + [int(bit) for bit in decryptBits])

    def getModel(self):
        """No neural model is implemented yet."""
        pass

    def getPriority(self, score):
        """Return the queue priority: own score plus a small per-generation penalty.

        NOTE(review): the `score` argument is ignored and the state's own
        score is used instead; kept as in the original.
        """
        return self.score + (1 / self.keyLenBits) * 0.01 * self.generation
|
||||||
|
|
||||||
|
class XorKnownPlaintextAndKeylen(KnownPlaintextAndKeylen):
    """Known-plaintext search against a repeating-key XOR cipher."""

    def _decrypt(self):
        """XOR every ciphertext character with the (repeated) candidate key."""
        from itertools import cycle

        key = self.getKey()
        if not key:
            # Fewer than 8 key bits give an empty key; nothing can be decrypted.
            return ''
        # The key is cycled so the whole ciphertext is decrypted; a plain zip
        # would cut the output off at the key length.
        return ''.join(chr(ord(a) ^ ord(b)) for a, b in zip(self.ciphertext, cycle(key)))
|
||||||
|
|
||||||
|
class RC4KnownPlaintextAndKeylen(KnownPlaintextAndKeylen):
    """Known-plaintext search against RC4 (via the arc4 package)."""

    def _decrypt(self):
        """Decrypt the ciphertext with the candidate key and return it as str.

        latin-1 maps every byte value to exactly one character, so decoding
        never fails for wrong keys (ascii would raise on any byte >= 0x80).
        """
        key = self.getKey().encode("latin-1")
        data = self.ciphertext
        if isinstance(data, str):
            data = data.encode("latin-1")
        return ARC4(key).decrypt(data).decode("latin-1")


# Backward-compatible alias for the original, misspelled class name.
RC4KnownPlayintextAndKeylen = RC4KnownPlaintextAndKeylen


if __name__ == "__main__":
    # The state needs a plaintext/ciphertext pair and the key length, so a
    # small self-contained example is built here.
    secretKey = b"k"
    plaintext = "vacuumDecay"
    ciphertext = ARC4(secretKey).encrypt(plaintext.encode("latin-1"))
    vd = WeakSolver(RC4KnownPlaintextAndKeylen(plaintext, ciphertext, len(secretKey) * 8))
    # NOTE(review): as in the original, the solver is only constructed here;
    # call vd.play() to actually start the search.
|
||||||
|
|
||||||
|
# TODO:
|
||||||
|
# - Should use bytes for everything (not array of ints / string)
|
18
mycelia.py
Normal file
18
mycelia.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
class State:
    """Empty placeholder; no attributes or methods are defined yet."""
|
||||||
|
|
||||||
|
|
||||||
|
class Action:
    """Empty placeholder; no attributes or methods are defined yet."""
|
||||||
|
|
||||||
|
|
||||||
|
class BotAction:
    """Empty placeholder; no attributes or methods are defined yet."""
|
||||||
|
|
||||||
|
|
||||||
|
class PlayerAction:
    """Empty placeholder; no attributes or methods are defined yet."""
|
||||||
|
|
||||||
|
|
||||||
|
class EnvAction:
    """Empty placeholder; no attributes or methods are defined yet."""
|
61
tictactoe.py
Normal file
61
tictactoe.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
from vacuumDecay import *
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class TTTState(State):
    """Tic-tac-toe state: a flat board of nine cells holding None or a player index."""

    def __init__(self, turn=0, generation=0, playersNum=2, board=None):
        """Create a state; an omitted board means an empty 3x3 grid."""
        if board is None:
            board = np.array([None] * 9)
        self.turn = turn
        self.generation = generation
        self.playersNum = playersNum
        self.board = board
        self.score = self.getScore()

    def mutate(self, action):
        """Return a new state with cell action.data claimed by the current player."""
        newBoard = np.copy(self.board)
        newBoard[action.data] = self.turn
        return TTTState(
            turn=(self.turn + 1) % self.playersNum,
            generation=self.generation + 1,
            playersNum=self.playersNum,
            board=newBoard,
        )

    def getAvaibleActions(self):
        """Yield one action per empty cell."""
        for i in range(9):
            if self.board[i] is None:
                yield Action(self.turn, i)

    def checkWin(self):
        """Return the winning player's index, -1 for a draw, None while the game runs."""
        s = self.board
        lines = []
        for i in range(3):
            lines.append((i, i + 3, i + 6))              # column i
            lines.append((i * 3, i * 3 + 1, i * 3 + 2))  # row i
        lines.append((0, 4, 8))                          # diagonals
        lines.append((2, 4, 6))
        for a, b, c in lines:
            if s[a] is not None and s[a] == s[b] == s[c]:
                return s[a]
        if any(cell is None for cell in s):
            return None
        return -1

    def __str__(self):
        """Render the board as three lines; empty cells are shown as '.'."""
        rows = []
        for r in range(3):
            row = self.board[r * 3:r * 3 + 3]
            rows.append(" ".join('.' if p is None else str(p) for p in row))
        return "\n".join(rows)

    def getTensor(self):
        """Return a float tensor of length 10: the turn followed by the nine cells.

        Empty cells are encoded as -1, because None cannot be stored in a
        numeric tensor.
        """
        cells = [-1 if cell is None else int(cell) for cell in self.board]
        return torch.tensor([self.turn] + cells, dtype=torch.float32)

    @classmethod
    def getModel(cls):
        """Return a small fully connected network mapping the 10-value tensor to one score."""
        return torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 3),
            torch.nn.Sigmoid(),
            torch.nn.Linear(3, 1),
        )
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # vacuumDecay defines WeakSolver.play(); there is no VacuumDecay class
    # and no weakPlay() method.
    vd = WeakSolver(TTTState())
    vd.play()
|
432
vacuumDecay.py
Normal file
432
vacuumDecay.py
Normal file
@ -0,0 +1,432 @@
|
|||||||
|
import time
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
import torch
|
||||||
|
#from multiprocessing import Event
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from threading import Event
|
||||||
|
from queue import PriorityQueue, Empty
|
||||||
|
|
||||||
|
|
||||||
|
class Action():
    """Holds the data representing an action.

    Actions are applied to a State in State.mutate.
    """

    def __init__(self, player, data):
        """Store the acting player's index and the game-specific payload."""
        self.player = player
        self.data = data

    def __eq__(self, other):
        """Two actions of the same class are equal when their data prints identically.

        The player is deliberately not compared.
        """
        # This should be implemented differently
        # Two actions of different generations will never be compared
        if type(other) is not type(self):
            return False
        return str(self.data) == str(other.data)

    def __hash__(self):
        """Hash consistently with __eq__ so actions can be used in sets and dicts."""
        return hash(str(self.data))

    def __str__(self):
        """Return the visual representation, e.g. '<P0-4>'."""
        # should start with < and end with >
        return "<P" + str(self.player) + "-" + str(self.data) + ">"
|
||||||
|
|
||||||
|
class NaiveUniverse():
    """Universe that never merges anything: every branch stays as it is."""

    def __init__(self):
        """Nothing to set up; no state is kept."""

    def merge(self, branch):
        """Return the given branch unchanged."""
        return branch
|
||||||
|
|
||||||
|
class BranchUniverse():
    """Universe that merges branches whose resulting states have equal tensors."""

    def __init__(self):
        # Maps a hashable snapshot of a state tensor to the first branch seen for it.
        self.branches = {}

    def merge(self, branch):
        """Return the known branch for this state, registering `branch` if it is new.

        A branch is always returned. States without a tensor (getTensor()
        returns None) are never merged.
        """
        key = self._freeze(branch.node.state.getTensor())
        if key is None:
            return branch
        return self.branches.setdefault(key, branch)

    @staticmethod
    def _freeze(tensor):
        """Turn a tensor / nested list into nested tuples usable as a dict key.

        Tensors themselves hash by identity, so two equal tensors would never
        be found under the same key.
        """
        if tensor is None:
            return None
        if hasattr(tensor, "tolist"):
            tensor = tensor.tolist()
        if isinstance(tensor, (list, tuple)):
            return tuple(BranchUniverse._freeze(item) for item in tensor)
        return tensor
|
||||||
|
|
||||||
|
class Branch():
    """Edge of the search tree: a state, an action, and the node the action leads to.

    Construction goes through the universe, so Branch(...) may return an
    already existing branch instead of a new one.
    """

    def __new__(cls, universe, preState, action, parent=None):  # fancy!
        """Build the branch for `action` applied to `preState` and merge it.

        universe -- object whose merge(branch) picks the branch to use
        preState -- the State the action is applied to
        action   -- the Action to apply
        parent   -- optional Node owning preState; stored as the new node's parent
        """
        # A real instance is created here; storing the attributes on the class
        # would make every branch share (and overwrite) the same state.
        branch = super().__new__(cls)
        branch.preState = preState
        branch.action = action
        postState = preState.mutate(action)
        branch.node = Node(postState, universe=universe,
                           parent=parent, lastAction=action)
        merged = universe.merge(branch)
        # Fall back to the new branch if the universe does not return one.
        return branch if merged is None else merged
|
||||||
|
|
||||||
|
|
||||||
|
class State(ABC):
    """Abstract base class for game states.

    Holds a representation of the current game state.
    Allows retrieving available actions (getAvaibleActions) and applying them (mutate).
    Mutations return a new State and should not have any effect on the current State.
    Allows checking itself for a win (checkWin) or scoring itself based on a
    simple heuristic (getScore).
    The calculated score should be 0 when won; higher when in a worse state;
    highest for losing.
    getPriority is used for prioritising certain Nodes / States when
    expanding / walking the tree.
    """

    def __init__(self, turn=0, generation=0, playersNum=2):
        """Store turn, generation and player count, then cache the initial score."""
        self.turn = turn
        self.generation = generation
        self.playersNum = playersNum
        self.score = self.getScore()

    @abstractmethod
    def mutate(self, action):
        """Return a new state with the supplied action performed; self must not change."""
        # Template only: State itself is abstract, subclasses return their own class.
        return State(turn=(self.turn + 1) % self.playersNum, generation=self.generation + 1, playersNum=self.playersNum)

    @abstractmethod
    def getAvaibleActions(self):
        """Return an iterable of all possible actions."""
        return []

    # improveMe
    def getPriority(self, score):
        """Return the priority used for ordering the priority queue.

        Priority should not change for the same root.
        Lower priorities get worked on first.
        Higher generations should have slightly higher priority.
        """
        return score + self.generation * 0.1

    @abstractmethod
    def checkWin(self):
        """Return the game result.

        -1    -> draw
        None  -> not ended
        n e N -> player n won
        """
        return None

    # improveMe
    def getScore(self):
        """Return a score with 0 <= score <= 1; close to zero when player 0 is winning."""
        w = self.checkWin()
        if w is None:
            return 0.5
        if w == 0:
            return 0
        if w == -1:
            return 0.9
        return 1

    @abstractmethod
    def __str__(self):
        """Return a visual representation of the state."""
        return "[#]"

    @abstractmethod
    def getTensor(self):
        """Return a tensor describing the state."""
        return torch.tensor([0])

    @classmethod
    def getModel(cls):
        """Return the neural model for this state class; the base class has none."""
        pass

    def getScoreNeural(self):
        """Return the score predicted by self.model for this state's tensor.

        NOTE(review): nothing in this class assigns self.model; it has to be
        set by the caller before this method is used.
        """
        return self.model(self.getTensor())
|
||||||
|
|
||||||
|
|
||||||
|
class Node():
    """Node of the search tree: a state plus its children, score and search workers."""

    def __init__(self, state, universe=None, parent=None, lastAction=None, playersNum=2):
        """Create a node.

        state      -- the State held by this node
        universe   -- object used to merge branches (default: a NaiveUniverse)
        parent     -- the Node this node was expanded from, or None for a root
        lastAction -- the Action that led to this node, or None for a root
        playersNum -- number of players, used when aggregating child scores
        """
        self.state = state
        if universe is None:
            universe = NaiveUniverse()
            # TODO: Maybe add self to new BranchUniverse?
        self.universe = universe
        self.parent = parent
        self.lastAction = lastAction
        self.playersNum = playersNum

        self.childs = None      # None = not expanded yet, [] = no actions available
        self.score = state.getScore()
        self.done = Event()     # set to stop the workers started by beginWalk
        self.threads = []
        self.walking = False
        self.alive = True

    def expand(self, shuffle=True):
        """Create the child branches of this node if that has not happened yet.

        Returns False when the state offers no actions, True otherwise.
        """
        if self.childs is not None:
            return True
        childs = []
        for action in self.state.getAvaibleActions():
            branch = Branch(self.universe, self.state, action)
            if branch is None:
                # The universe did not hand back a branch; nothing to attach.
                continue
            child = branch.node
            if not isinstance(child.parent, Node):
                # Score updates travel upwards through Node parents.
                child.parent = self
                child.playersNum = self.playersNum
            childs.append(branch)
        self.childs = childs
        if not childs:
            return False
        if shuffle:
            random.shuffle(self.childs)
        return True

    def _perform(self, action):
        """Return the child branch reached by `action`, stopping a running walk.

        Raises PerformOnUnexpandedNodeException, PerformOnTerminalNodeException
        or IllegalActionException.
        """
        if self.childs is None:
            raise PerformOnUnexpandedNodeException()
        if not self.childs:
            raise PerformOnTerminalNodeException()
        for child in self.childs:
            if child.node.lastAction == action:
                # The walk may already have been stopped by the caller.
                if self.walking:
                    self.endWalk()
                return child
        raise IllegalActionException()

    def performBot(self):
        """Return the child branch with the lowest score (later children win ties).

        Raises NotBotsTurnException when it is not player 0's turn.
        """
        if self.state.turn != 0:
            raise NotBotsTurnException()
        if self.childs is None:
            raise PerformOnUnexpandedNodeException()
        if not self.childs:
            raise PerformOnTerminalNodeException()
        if self.walking:
            self.endWalk()
        bChild = self.childs[0]
        for child in self.childs[1:]:
            if child.node.score <= bChild.node.score:
                bChild = child
        return bChild

    def performPlayer(self, action):
        """Return the child branch for a non-bot player's action.

        Raises NotPlayersTurnException when it is player 0's turn.
        """
        if self.state.turn == 0:
            raise NotPlayersTurnException()
        return self._perform(action)

    def getAvaibleActions(self):
        """Return the actions available in this node's state."""
        return self.state.getAvaibleActions()

    def getLastAction(self):
        """Return the action that led to this node."""
        return self.lastAction

    def beginWalk(self, threadNum=1):
        """Start `threadNum` worker threads that expand the tree below this node."""
        if self.walking:
            raise Exception("Already Walking")
        self.walking = True
        self.queue = PriorityQueue()
        self.done.clear()
        self.expand()
        self._activateEdge()
        for i in range(threadNum):
            t = threading.Thread(target=self._worker)
            t.start()
            self.threads.append(t)

    def endWalk(self):
        """Signal the workers to stop and wait until they have finished."""
        if not self.walking:
            raise Exception("Not Walking")
        self.done.set()
        for t in self.threads:
            t.join()
        # Finished threads are dropped so the next walk starts with a clean list.
        self.threads = []
        self.walking = False

    def walkUntilDone(self):
        """Start a walk if necessary and block until the workers have exited.

        NOTE(review): workers only exit once self.done is set, so this blocks
        until something sets it.
        """
        if not self.walking:
            self.beginWalk()
        for t in self.threads:
            t.join()
        self.done.set()

    def syncWalk(self, time, threads=16):
        """Walk for `time` seconds using `threads` worker threads, then stop."""
        # The parameter `time` shadows the time module inside this method.
        from time import sleep
        self.beginWalk(threadNum=threads)
        sleep(time)
        self.endWalk()

    def _worker(self):
        """Worker loop: take nodes from the queue, expand them, enqueue their children."""
        while not self.done.is_set():
            try:
                # A short timeout keeps the loop responsive to `done` without spinning.
                node = self.queue.get(timeout=0.05)
            except Empty:
                continue
            try:
                if node.alive:
                    if node.expand():
                        node._updateScore()
                    if self.done.is_set():
                        break
                    # Only states of unfinished games are explored further.
                    if node.state.checkWin() is None:
                        for c in node.childs:
                            self.queue.put(c.node)
            finally:
                self.queue.task_done()

    def _activateEdge(self, node=None):
        """Enqueue every unexpanded node reachable from `node` (default: self)."""
        if node is None:
            node = self
        if node.childs is None:
            self.queue.put(node)
        elif node.alive:
            for c in node.childs:
                self._activateEdge(node=c.node)

    def __lt__(self, other):
        """Order nodes by priority; used by the priority queue."""
        return self.state.getPriority(self.score) < other.state.getPriority(other.score)

    # improveMe
    def _calcAggScore(self):
        """Recompute this node's score from its children's scores.

        Player 0 takes the minimum; with two players the opponent takes the
        maximum; with more players the mean of maximum and average is used.
        """
        if self.childs:
            scores = [c.node.score for c in self.childs]
            if self.state.turn == 0:
                self.score = min(scores)
            elif self.playersNum == 2:
                self.score = max(scores)
            else:
                # Note: This might be tweaked
                self.score = (max(scores) + sum(scores) / len(scores)) / 2

    def _updateScore(self):
        """Re-aggregate the score and push it upwards when it changed."""
        oldScore = self.score
        self._calcAggScore()
        if self.score != oldScore:
            self._pushScore()

    def _pushScore(self):
        """Pass a changed score on to the parent; a root with score 0 sets its done event."""
        if self.parent is not None:
            self.parent._updateScore()
        elif self.score == 0:
            self.done.set()

    def __str__(self):
        """Render the last action (or ROOT), the turn, the state and the score."""
        s = []
        if self.lastAction is None:
            s.append("[ {ROOT} ]")
        else:
            s.append("[ -> " + str(self.lastAction) + " ]")
        s.append("[ turn: " + str(self.state.turn) + " ]")
        s.append(str(self.state))
        s.append("[ score: " + str(self.score) + " ]")
        return '\n'.join(s)
|
||||||
|
|
||||||
|
|
||||||
|
class WeakSolver():
    """Plays a game from a given state: player 0 is the bot, other turns are read from stdin."""

    def __init__(self, state):
        """Wrap the initial state in a root Node."""
        self.node = Node(state)

    def play(self):
        """Alternate steps until checkWin() reports a result, then print it."""
        while self.node.state.checkWin() is None:
            self.step()
        print(self.node)
        print("[*] " + str(self.node.state.checkWin()) + " won!")
        if self.node.walking:
            self.node.endWalk()

    def step(self):
        """Perform one move: the bot's on turn 0, otherwise the human player's."""
        if self.node.state.turn == 0:
            self.botStep()
        else:
            self.playerStep()

    def botStep(self):
        """Let the bot choose its best child and advance to it."""
        if self.node.walking:
            self.node.endWalk()
        self.node.expand()
        self.node = self.node.performBot().node
        print("[*] Bot did " + str(self.node.lastAction))

    def playerStep(self):
        """Ask the human for an action (an integer) until a legal one is entered.

        The tree search runs in the background while waiting for input.
        """
        self.node.beginWalk()
        print(self.node)
        while True:
            try:
                newBranch = self.node.performPlayer(
                    Action(self.node.state.turn, int(input("[#]> "))))
            except (IllegalActionException, ValueError):
                # ValueError: the input was not an integer.
                print("[!] Illegal Action")
            else:
                break
        # Performing the action may already have stopped the walk.
        if self.node.walking:
            self.node.endWalk()
        # performPlayer returns a Branch; the solver tracks its Node.
        self.node = newBranch.node
|
||||||
|
|
||||||
|
|
||||||
|
class NeuralTrainer():
    """Trains the neural model of a State class on (state tensor, score) samples."""

    def __init__(self, StateClass):
        """Store the State class and build its model.

        The model comes from StateClass.buildModel() if that exists, otherwise
        from StateClass.getModel().
        """
        self.State = StateClass
        buildModel = getattr(StateClass, "buildModel", None) or StateClass.getModel
        self.model = buildModel()

    def train(self, states, scores, rounds=2000, learningRate=1e-6):
        """Run `rounds` single-sample gradient-descent steps.

        states       -- sequence of input tensors
        scores       -- sequence of target tensors, index-aligned with states
        rounds       -- number of update steps; samples are reused cyclically
        learningRate -- step size of the manual parameter update
        """
        loss_fn = torch.nn.MSELoss(reduction='sum')
        for t in range(rounds):
            idx = t % len(states)
            y_pred = self.model(states[idx])
            y = scores[idx]
            loss = loss_fn(y_pred, y)
            print(t, loss.item())
            self.model.zero_grad()
            loss.backward()
            # Manual SGD step; no_grad keeps the update out of the autograd graph.
            with torch.no_grad():
                for param in self.model.parameters():
                    param -= learningRate * param.grad

    def setWeights(self):
        """Not implemented yet."""
        pass

    def getWeights(self):
        """Not implemented yet."""
        pass

    def loadWeights(self):
        """Not implemented yet."""
        pass

    def storeWeights(self):
        """Not implemented yet."""
        pass
|
||||||
|
|
||||||
|
|
||||||
|
class SelfPlayDataGen():
    """Generates training samples by letting one search tree per player play a game."""

    def __init__(self, StateClass, playersNum, compTime=30):
        """Store the State class, the number of players and the thinking time per move (seconds)."""
        self.State = StateClass
        self.playersNum = playersNum
        self.compTime = compTime
        self.gameStates = []

    def game(self):
        """Play one game and return the result of checkWin().

        After every move the tensor and score of node 0 are appended to
        self.gameStates.
        """
        # One tree per player; each sees the game with itself as player 0.
        self.nodes = [
            Node(self.State(turn=(-p) % self.playersNum, generation=0,
                            playersNum=self.playersNum))
            for p in range(self.playersNum)
        ]

        while True:
            winner = self.nodes[0].state.checkWin()
            if winner is not None:
                return winner
            for n in self.nodes:
                n.beginWalk()
            time.sleep(self.compTime)
            for n in self.nodes:
                n.endWalk()
            self.step()
            self.gameStates.append(
                [self.nodes[0].state.getTensor(), self.nodes[0].score])

    def step(self):
        """Let the tree whose turn it is choose a move and replay it in all other trees."""
        turn = self.nodes[0].state.turn
        # performBot / performPlayer return a Branch; the trees track its Node.
        self.nodes[turn] = self.nodes[turn].performBot().node
        action = self.nodes[turn].lastAction
        for n in range(self.playersNum):
            if n != turn:
                action.player = 0
                self.nodes[n] = self.nodes[n].performPlayer(action).node
        return self.nodes[0].state.checkWin()
|
||||||
|
|
||||||
|
|
||||||
|
class VacuumDecayException(Exception):
    """Base class of the exceptions defined in this module."""
|
||||||
|
|
||||||
|
|
||||||
|
class IllegalActionException(VacuumDecayException):
    """Raised by Node._perform when no child matches the requested action."""
|
||||||
|
|
||||||
|
|
||||||
|
class PerformOnUnexpandedNodeException(VacuumDecayException):
    """Raised when an action is performed on a node that has not been expanded."""
|
||||||
|
|
||||||
|
|
||||||
|
class PerformOnTerminalNodeException(VacuumDecayException):
    """Raised when an action is performed on a node that has no children."""
|
||||||
|
|
||||||
|
|
||||||
|
class IllegalTurnException(VacuumDecayException):
    """Base class for errors about acting on the wrong turn."""
|
||||||
|
|
||||||
|
|
||||||
|
class NotBotsTurnException(IllegalTurnException):
    """Raised by Node.performBot when the state's turn is not 0."""
|
||||||
|
|
||||||
|
|
||||||
|
class NotPlayersTurnException(IllegalTurnException):
    """Raised by Node.performPlayer when the state's turn is 0."""
|
Loading…
Reference in New Issue
Block a user