Many squashed commits

parent f210c4f670
commit 6262aea6f0

README.md (25 changed lines):
@@ -2,21 +2,24 @@
 Project vacuumDecay is a framework for building AIs for games.
 
 Available architectures are:
+
 - those used in Deep Blue (mini-max / expecti-max)
 - advanced expecti-max exploration based on utility heuristics
 - those used in AlphaGo Zero (knowledge distillation using neural networks)
 
 A new AI is created by subclassing the State class and defining the following functionality (mycelia.py provides a template; see the sketch below):
+
 - initialization (generating the game board or similar)
 - getting available actions for the current situation (returns an Action object, which can be subclassed to add additional functionality)
 - applying an action (the state itself should be immutable; a new state should be returned)
 - checking for a winning condition (should return None if the game has not yet ended)
 - (optional) a getter for a string representation of the current state
-- (optional) a heuristic for the winning condition (greatly improves capability)
+- (optional) a heuristic for the winning condition (greatly improves capability for expecti-max)
 - (optional) a getter for a tensor that describes the current game state (required for knowledge distillation)
 - (optional) an interface to allow a human to select an action
 
 ### Current state of the project
 
 The only thing that currently works is the AI for Ultimate TicTacToe.
 It uses a trained neural heuristic (a "neuristic").
 You can train it or play against it (this will also train it) using 'python ultimatetictactoe.py'.
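A minimal sketch of such a State subclass, using a made-up one-heap Nim game purely for illustration. The method names (getAvaibleActions, mutate, checkWin, getTensor) and the attribute pattern follow vacuumDecay/base.py and the game modules changed below; the Nim rules and the NimState name are hypothetical:

    import torch
    from vacuumDecay import State, Action

    class NimState(State):
        def __init__(self, curPlayer=0, generation=0, stones=15):
            # Attributes set directly, mirroring ChessState below
            self.curPlayer = curPlayer
            self.generation = generation
            self.playersNum = 2
            self.stones = stones  # the whole "board": one heap of stones

        def getAvaibleActions(self):
            # One Action per legal move: take 1-3 stones
            for take in range(1, min(3, self.stones) + 1):
                yield Action(self.curPlayer, take)

        def mutate(self, action):
            # States are immutable: always return a fresh one
            return NimState(curPlayer=(self.curPlayer + 1) % self.playersNum,
                            generation=self.generation + 1,
                            stones=self.stones - action.data)

        def checkWin(self):
            # None while the game is running, else the index of the winner
            if self.stones == 0:
                return (self.curPlayer + 1) % self.playersNum
            return None

        def getTensor(self, player=None):
            return torch.tensor([self.curPlayer, self.stones], dtype=torch.float)

        def __str__(self):
            return "[stones: " + str(self.stones) + "]"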
pyproject.toml:
@@ -7,6 +7,7 @@ name = "vacuumDecay"
 version = "0.1.0"
 dependencies = [
     "torch",
+    "numpy",
     "flask",
     "flask-socketio",
     "networkx",
vacuumDecay/__init__.py:
@@ -1,4 +1,4 @@
 from vacuumDecay.runtime import Runtime, NeuralRuntime, Trainer
-from vacuumDecay.base import Node, Action, Universe, QueueingUniverse
+from vacuumDecay.base import Node, State, Action, Universe, QueueingUniverse
 from vacuumDecay.utils import choose
 from vacuumDecay.run import main
vacuumDecay/base.py:
@@ -1,19 +1,24 @@
 import torch
+import time
+import random
+from math import sqrt
 from abc import ABC, abstractmethod
 from queue import PriorityQueue, Empty
 from dataclasses import dataclass, field
 from typing import Any
+from torch import nn
+import torch.nn.functional as F
 
 from vacuumDecay.utils import choose
 
 class Action():
     # Should hold the data representing an action
     # Actions are applied to a State in State.mutate
 
     def __init__(self, player, data):
         self.player = player
         self.data = data
 
+    # ImproveMe
     def __eq__(self, other):
         # This should be implemented differently
         # Two actions of different generations will never be compared
@@ -21,23 +26,33 @@ class Action():
             return False
         return str(self.data) == str(other.data)
 
+    # ImproveMe
     def __str__(self):
         # should return a visual representation of this action
         # should start with < and end with >
         return "<P"+str(self.player)+"-"+str(self.data)+">"
 
+    # ImproveMe
     def getImage(self, state):
         # Should return an image representation of this action given the current state
         # Return None if not implemented
         return None
 
+    # ImproveMe
+    def getTensor(self, state, player=None):
+        # Should return a complete description of the action (including the previous state)
+        # This default will work, but may be suboptimal...
+        return (state.getTensor(), state.mutate(self).getTensor())
+
 class State(ABC):
     # Holds a representation of the current game-state
     # Allows retrieving available actions (getAvaibleActions) and applying them (mutate)
     # Mutations return a new State and should not have any effect on the current State
     # Allows checking itself for a win (checkWin) or scoring itself based on a simple heuristic (getScore)
     # The calculated score should be 0 when won; higher when in a worse state; highest for losing
-    # getPriority is used for prioritising certain Nodes / States when expanding / walking the tree
+    # getPriority is used for prioritising certain Nodes / States when expanding / walking the tree (TODO: Remove)
+
+    # Abstract methods need to be overridden; ImproveMe methods can be overridden
 
     def __init__(self, curPlayer=0, generation=0, playersNum=2):
         self.curPlayer = curPlayer
@@ -81,10 +96,10 @@ class State(ABC):
         if w == None:
             return 0.5
         if w == player:
-            return 0
-        if w == -1:
-            return 0.9
-        return 1
+            return 1
+        if w == -1:
+            return 0.1
+        return 0
 
     @abstractmethod
     def __str__(self):
@@ -92,23 +107,40 @@ class State(ABC):
         return "[#]"
 
     @abstractmethod
-    def getTensor(self, player=None, phase='default'):
+    def getTensor(self, player=None):
         if player == None:
             player = self.curPlayer
         return torch.tensor([0])
 
     @classmethod
-    def getModel(cls, phase='default'):
+    def getVModel(cls):
+        # input will be output from state.getTensor
         pass
 
-    def getScoreNeural(self, model, player=None, phase='default'):
-        return model(self.getTensor(player=player, phase=phase)).item()
+    # ImproveMe
+    @classmethod
+    def getQModel(cls):
+        # input will be output from action.getTensor
+        return DefaultQ(cls.getVModel())
+
+    def getScoreNeural(self, model, player=None):
+        return model(self.getTensor(player=player)).item()
 
+    # ImproveMe
     def getImage(self):
         # Should return an image representation of this state
         # Return None if not implemented
         return None
 
+class DefaultQ(nn.Module):
+    def __init__(self, vModel):
+        super().__init__()
+        self.V = vModel
+
+    def forward(self, inp):
+        s, s_prime = inp
+        v, v_prime = self.V(s), self.V(s_prime)
+        return F.sigmoid(v_prime - v)
+
 class Universe():
     def __init__(self):
         self.scoreProvider = 'naive'
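The DefaultQ wrapper above turns any value network V into an action scorer: given the (state, successor) tensor pair produced by Action.getTensor, it returns sigmoid(V(s') - V(s)), i.e. a high score when the action improves the position. A usage sketch (the toy value network here is made up for illustration and stands in for a real getVModel()):

    import torch
    from torch import nn

    toy_v = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))  # stand-in for getVModel()
    q = DefaultQ(toy_v)

    s = torch.zeros(4)       # tensor of the state before the action
    s_prime = torch.ones(4)  # tensor of the state after the action
    print(q((s, s_prime)))   # sigmoid(V(s') - V(s))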
@@ -160,3 +192,208 @@ class QueueingUniverse(Universe):
 
     def activateEdge(self, head):
         head._activateEdge()
+
+class Node:
+    def __init__(self, state, universe=None, parent=None, lastAction=None):
+        self.state = state
+        if universe == None:
+            print('[!] No Universe defined. Spawning one...')
+            universe = Universe()
+        self.universe = universe
+        self.parent = parent
+        self.lastAction = lastAction
+
+        self._childs = None
+        self._scores = [None]*self.state.playersNum
+        self._strongs = [None]*self.state.playersNum
+        self._alive = True
+        self._cascadeMemory = 0  # Used for our alternative to alpha-beta pruning
+        self._winner = -2
+
+        self.leaf = True
+        self.last_updated = time.time()  # New attribute
+
+    def mark_update(self):
+        self.last_updated = time.time()
+
+    def kill(self):
+        self._alive = False
+
+    def revive(self):
+        self._alive = True
+
+    @property
+    def childs(self):
+        if self._childs == None:
+            self._expand()
+        return self._childs
+
+    def _expand(self):
+        self.leaf = False
+        self._childs = []
+        actions = self.state.getAvaibleActions()
+        for action in actions:
+            newNode = Node(self.state.mutate(action),
+                           self.universe, self, action)
+            self._childs.append(self.universe.merge(newNode))
+        self.mark_update()
+
+    def getStrongFor(self, player):
+        if self._strongs[player] != None:
+            return self._strongs[player]
+        else:
+            return self.getScoreFor(player)
+
+    def _pullStrong(self):
+        strongs = [None]*self.playersNum
+        has_winner = self.getWinner() != None
+        for p in range(self.playersNum):
+            cp = self.state.curPlayer
+            if has_winner:
+                strongs[p] = self.getScoreFor(p)
+            elif cp == p:
+                best = float('-inf')
+                for c in self.childs:
+                    if c.getStrongFor(p) > best:
+                        best = c.getStrongFor(p)
+                strongs[p] = best
+            else:
+                scos = [(c.getStrongFor(p), c.getStrongFor(cp)) for c in self.childs]
+                scos.sort(key=lambda x: x[1], reverse=True)
+                betterHalf = [sco for sco, osc in scos[:max(3, int(len(scos)/2))]]
+                strongs[p] = betterHalf[0]*0.9 + sum(betterHalf)/(len(betterHalf))*0.1
+        update = False
+        for s in range(self.playersNum):
+            if strongs[s] != self._strongs[s]:
+                update = True
+                break
+        self._strongs = strongs
+        if update:
+            if self.parent != None:
+                cascade = self.parent._pullStrong()
+            else:
+                cascade = 2
+            self._cascadeMemory = self._cascadeMemory/2 + cascade
+            self.mark_update()
+            return cascade + 1
+        self._cascadeMemory /= 2
+        return 0
+
+    def forceStrong(self, depth=3):
+        if depth == 0:
+            self.strongDecay()
+        else:
+            if len(self.childs):
+                for c in self.childs:
+                    c.forceStrong(depth-1)
+            else:
+                self.strongDecay()
+
+    def decayEvent(self):
+        for c in self.childs:
+            c.strongDecay()
+
+    def strongDecay(self):
+        if self._strongs == [None]*self.playersNum:
+            if not self.scoresAvaible():
+                self._calcScores()
+            self._strongs = self._scores
+            if self.parent:
+                return self.parent._pullStrong()
+            return 1
+        return None
+
+    def getSelfScore(self):
+        return self.getScoreFor(self.curPlayer)
+
+    def getScoreFor(self, player):
+        if self._scores[player] == None:
+            self._calcScore(player)
+        return self._scores[player]
+
+    def scoreAvaible(self, player):
+        return self._scores[player] != None
+
+    def scoresAvaible(self):
+        for p in self._scores:
+            if p == None:
+                return False
+        return True
+
+    def strongScoresAvaible(self):
+        for p in self._strongs:
+            if p == None:
+                return False
+        return True
+
+    def askUserForAction(self):
+        return self.state.askUserForAction(self.avaibleActions)
+
+    def _calcScores(self):
+        for p in range(self.state.playersNum):
+            self._calcScore(p)
+
+    def _calcScore(self, player):
+        winner = self._getWinner()
+        if winner != None:
+            if winner == player:
+                self._scores[player] = 1.0
+            elif winner == -1:
+                self._scores[player] = 0.1
+            else:
+                self._scores[player] = 0.0
+            return
+        if self.universe.scoreProvider == 'naive':
+            self._scores[player] = self.state.getScoreFor(player)
+        elif self.universe.scoreProvider == 'neural':
+            self._scores[player] = self.state.getScoreNeural(self.universe.v_model, player)
+        else:
+            raise Exception('Unknown Score-Provider')
+
+    def getPriority(self):
+        return self.state.getPriority(self.getSelfScore(), self._cascadeMemory)
+
+    @property
+    def playersNum(self):
+        return self.state.playersNum
+
+    @property
+    def avaibleActions(self):
+        r = []
+        for c in self.childs:
+            r.append(c.lastAction)
+        return r
+
+    @property
+    def curPlayer(self):
+        return self.state.curPlayer
+
+    def _getWinner(self):
+        return self.state.checkWin()
+
+    def getWinner(self):
+        if len(self.childs) == 0:
+            return -1
+        if self._winner == -2:
+            self._winner = self._getWinner()
+        return self._winner
+
+    def _activateEdge(self, dist=0):
+        if not self.strongScoresAvaible():
+            self.universe.newOpen(self)
+        else:
+            for c in self.childs:
+                if c._cascadeMemory > 0.001*(dist-2) or random.random() < 0.01:
+                    c._activateEdge(dist=dist+1)
+            self.mark_update()
+
+    def __str__(self):
+        s = []
+        if self.lastAction == None:
+            s.append("[ {ROOT} ]")
+        else:
+            s.append("[ -> "+str(self.lastAction)+" ]")
+        s.append("[ turn: "+str(self.state.curPlayer)+" ]")
+        s.append(str(self.state))
+        s.append("[ score: "+str(self.getScoreFor(0))+" ]")
+        return '\n'.join(s)
vacuumDecay/games/chess_game.py (new file, 248 lines):
@@ -0,0 +1,248 @@
+import numpy as np
+import torch as th
+from torch import nn
+import torch.nn.functional as F
+from PIL import Image
+import chess
+import chess.svg
+import io
+
+from vacuumDecay import State, Action, Runtime, NeuralRuntime, Trainer, choose, main
+
+class ChessAction(Action):
+    def __init__(self, player, data):
+        super().__init__(player, data)
+
+    def __str__(self):
+        return "<P"+str(self.player)+"-"+self.data.uci()+">"
+
+    def getImage(self, state=None):
+        return Image.open(io.BytesIO(chess.svg.board(board=state.board, format='png', squares=[self.data.from_square, self.data.to_square], arrows=[self.move])))
+
+    def getTensor(self, state):
+        board, additionals = state.getTensor()
+
+        tensor = np.zeros((8, 8), dtype=int)  # marks the move's from/to squares
+
+        # Mark the from_square and to_square
+        from_row, from_col = divmod(self.data.from_square, 8)
+        to_row, to_col = divmod(self.data.to_square, 8)
+
+        tensor[from_row, from_col] = 1  # Mark the "from" square
+        tensor[to_row, to_col] = 1  # Mark the "to" square
+
+        # Get the piece that was moved
+        pieceT = np.zeros((12), dtype=int)  # one-hot over the 12 piece types
+        piece = state.board.piece_at(self.data.from_square)
+        if piece:
+            piece_type = {
+                'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
+                'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11
+            }
+            pieceT[piece_type[piece.symbol()]] = 1
+
+        # Flatten the tensor and return as a PyTorch tensor
+        return (board, additionals, th.concat(tensor.flatten(), pieceT.flatten()))
+
+piece_values = {
+    chess.PAWN: 1,
+    chess.KNIGHT: 3,
+    chess.BISHOP: 3,
+    chess.ROOK: 5,
+    chess.QUEEN: 9
+}
+
+class ChessState(State):
+    def __init__(self, curPlayer=0, generation=0, board=None):
+        if type(board) == type(None):
+            board = chess.Board()
+        self.curPlayer = curPlayer
+        self.generation = generation
+        self.playersNum = 2
+        self.board = board
+
+    def mutate(self, action):
+        newBoard = self.board.copy()
+        newBoard.push(action.data)
+        return ChessState(curPlayer=(self.curPlayer+1)%2, board=newBoard)
+
+    # Calculate the total value of the pieces a player has on the board
+    def calculate_piece_value(self, board, color):
+        value = 0
+        for square in chess.scan_reversed(board.occupied_co[color]):
+            piece = board.piece_at(square)
+            if piece is not None:
+                value += piece_values.get(piece.piece_type, 0)
+        return value
+
+    # Estimate the winning probability for each player from material
+    def calculate_winning_probability(self):
+        white_piece_value = self.calculate_piece_value(self.board, chess.WHITE)
+        black_piece_value = self.calculate_piece_value(self.board, chess.BLACK)
+        total_piece_value = white_piece_value + black_piece_value
+
+        # Calculate winning probabilities
+        white_probability = white_piece_value / total_piece_value
+        black_probability = black_piece_value / total_piece_value
+
+        return white_probability, black_probability
+
+    def getScoreFor(self, player):
+        w = self.checkWin()
+        if w == None:
+            return self.calculate_winning_probability()[player]
+        if w == player:
+            return 1
+        if w == -1:
+            return 0.1
+        return 0
+
+    def getAvaibleActions(self):
+        for move in self.board.legal_moves:
+            yield ChessAction(self.curPlayer, move)
+
+    def checkWin(self):
+        if self.board.is_checkmate():
+            return (self.curPlayer+1)%2
+        elif self.board.is_stalemate():
+            return -1
+        return None
+
+    def __str__(self):
+        return str(self.board)
+
+    def getTensor(self):
+        board = self.board
+        piece_to_plane = {
+            'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
+            'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
+        }
+
+        tensor = np.zeros((12, 8, 8), dtype=int)
+        for square in chess.SQUARES:
+            piece = board.piece_at(square)
+            if piece:
+                plane = piece_to_plane[piece.symbol()]
+                row, col = divmod(square, 8)
+                tensor[plane, row, col] = 1
+
+        # Side to move
+        side_to_move = np.array([1 if board.turn == chess.WHITE else 0])
+
+        # Castling rights
+        castling_rights = np.array([
+            1 if board.has_kingside_castling_rights(chess.WHITE) else 0,
+            1 if board.has_queenside_castling_rights(chess.WHITE) else 0,
+            1 if board.has_kingside_castling_rights(chess.BLACK) else 0,
+            1 if board.has_queenside_castling_rights(chess.BLACK) else 0
+        ])
+
+        # En passant target square
+        en_passant = np.zeros((8, 8), dtype=int)
+        if board.ep_square:
+            row, col = divmod(board.ep_square, 8)
+            en_passant[row, col] = 1
+
+        # Half-move clock and full-move number
+        half_move_clock = np.array([board.halfmove_clock])
+        full_move_number = np.array([board.fullmove_number])
+
+        additionals = np.concatenate([
+            side_to_move,
+            castling_rights,
+            en_passant.flatten(),
+            half_move_clock,
+            full_move_number
+        ])
+
+        return (th.tensor(tensor), th.tensor(additionals))
+
+    @classmethod
+    def getVModel(cls):
+        return ChessV()
+
+    @classmethod
+    def getQModel(cls):
+        return ChessQ()
+
+    def getImage(self):
+        return Image.open(io.BytesIO(chess.svg.board(board=self.board, format='png')))
+
+class ChessV(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # CNN for the board tensor
+        self.conv1 = nn.Conv2d(12, 16, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
+        self.fc1 = nn.Linear(32 * 8 * 8, 256)
+
+        # FCNN for the board tensor
+        self.fc2 = nn.Linear(8 * 8, 64)
+
+        # FCNN for additional info
+        self.fc_additional1 = nn.Linear(71, 64)
+
+        # Combine all outputs
+        self.fc_combined1 = nn.Linear(256 + 64 + 64, 128)
+        self.fc_combined2 = nn.Linear(128, 1)
+
+    def forward(self, inp):
+        board_tensor, additional_info = inp
+        # Process the board tensor through the CNN
+        x = F.relu(self.conv1(board_tensor))
+        x = F.relu(self.conv2(x))
+        x = x.view(x.size(0), -1)  # Flatten the tensor
+        x = F.relu(self.fc1(x))
+
+        y = F.relu(self.fc2(board_tensor.view(board_tensor.size(0), -1)))
+
+        # Process the additional info through the FCNN
+        z = F.relu(self.fc_additional1(additional_info))
+
+        # Combine the outputs
+        combined = th.cat((x, y, z), dim=1)
+        combined = F.relu(self.fc_combined1(combined))
+        logit = self.fc_combined2(combined)
+
+        return logit
+
+class ChessQ(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # CNN for the board tensor
+        self.conv1 = nn.Conv2d(12, 16, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
+        self.fc1 = nn.Linear(32 * 8 * 8, 256)
+
+        # FCNN for the board tensor
+        self.fc2 = nn.Linear(8 * 8, 64)
+
+        # FCNN for additional info
+        self.fc_additional1 = nn.Linear(71, 64)
+
+        # Combine all outputs
+        self.fc_combined1 = nn.Linear(256 + 64 + 64, 128)
+        self.fc_combined2 = nn.Linear(128, 1)
+
+    def forward(self, inp):
+        board_tensor, additional_info, action = inp
+        # Process the board tensor through the CNN
+        x = F.relu(self.conv1(board_tensor))
+        x = F.relu(self.conv2(x))
+        x = x.view(x.size(0), -1)  # Flatten the tensor
+        x = F.relu(self.fc1(x))
+
+        y = F.relu(self.fc2(board_tensor.view(board_tensor.size(0), -1)))
+
+        # Process the additional info through the FCNN
+        z = F.relu(self.fc_additional1(additional_info))
+
+        # Combine the outputs
+        combined = th.cat((x, y, z), dim=1)
+        combined = F.relu(self.fc_combined1(combined))
+        logit = self.fc_combined2(combined)
+
+        return logit
+
+if __name__=="__main__":
+    main(ChessState, start_visualizer=False)
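A quick shape check of the encoding above (a sketch; assumes python-chess is installed and the classes above are importable): the board becomes 12 one-hot 8x8 planes, and the scalar extras concatenate to 71 values (1 side-to-move + 4 castling rights + 64 en-passant squares + 1 half-move clock + 1 full-move number), which is exactly the input size of nn.Linear(71, 64) in ChessV and ChessQ.

    state = ChessState()
    board_planes, additionals = state.getTensor()
    print(board_planes.shape)  # torch.Size([12, 8, 8]) - one plane per piece type and colour
    print(additionals.shape)   # torch.Size([71]) = 1 + 4 + 64 + 1 + 1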
TicTacToe game module:
@@ -25,6 +25,9 @@ class TTTAction(Action):
         draw.line((x+40, y-40, x-40, y+40), fill='red', width=2)
         return img
 
+    def getTensor(self, state, player=None):
+        return torch.concat(torch.tensor([self.turn]), torch.tensor(state.board), torch.tensor(state.mutate(self).board))
+
 class TTTState(State):
     def __init__(self, curPlayer=0, generation=0, playersNum=2, board=None):
         if type(board) == type(None):
@@ -66,17 +69,29 @@ class TTTState(State):
         s.append(" ".join([str(p) if p!=None else '.' for p in self.board[l*3:][:3]]))
         return "\n".join(s)
 
-    def getTensor(self):
-        return torch.tensor([self.turn] + self.board)
+    def getTensor(self, player=None):
+        return torch.concat(torch.tensor([self.curPlayer]), torch.tensor(self.board))
 
     @classmethod
-    def getModel():
+    def getVModel(cls):
         return torch.nn.Sequential(
             torch.nn.Linear(10, 10),
-            torch.nn.ReLu(),
+            torch.nn.ReLU(),
             torch.nn.Linear(10, 3),
+            torch.nn.ReLU(),
+            torch.nn.Linear(3, 1),
+            torch.nn.Sigmoid(),
+        )
+
+    @classmethod
+    def getQModel(cls):
+        return torch.nn.Sequential(
+            torch.nn.Linear(20, 12),
+            torch.nn.ReLU(),
+            torch.nn.Linear(12, 3),
+            torch.nn.ReLU(),
+            torch.nn.Linear(3, 1),
             torch.nn.Sigmoid(),
-            torch.nn.Linear(3,1)
         )
 
     def getImage(self):
@@ -98,4 +113,4 @@ class TTTState(State):
         return img
 
 if __name__=="__main__":
-    main(TTTState)
+    main(TTTState, start_visualizer=False)
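For reference, the TicTacToe state tensor is 10 values (1 for the current player plus 9 cells), which matches the Linear(10, 10) input layer of getVModel above. A small illustrative check, with the input built by hand since it only exercises the model:

    import torch

    v = TTTState.getVModel()
    x = torch.cat((torch.tensor([0.0]), torch.zeros(9)))  # player marker + empty board
    print(v(x).shape)  # torch.Size([1]): one sigmoid win-probability estimate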
Ultimate TicTacToe game module (ultimatetictactoe.py):
@@ -3,7 +3,7 @@ A lot of this code was stolen from Pulkit Maloo (https://github.com/pulkitmaloo/...)
 """
 import numpy as np
 import torch
-from troch import nn
+from torch import nn
 from PIL import Image, ImageDraw
 
 from collections import Counter
@@ -11,8 +11,11 @@ import itertools
 
 from vacuumDecay import State, Action, Runtime, NeuralRuntime, Trainer, choose, main
 
-class TTTState(State):
+class UTTTAction(Action):
+    def __init__(self, player, data):
+        super().__init__(player, data)
+
+class UTTTState(State):
     def __init__(self, curPlayer=0, generation=0, playersNum=2, board=None, lastMove=-1):
         if type(board) == type(None):
             board = "." * 81
@@ -48,7 +51,7 @@ class TTTState(State):
     def mutate(self, action):
         newBoard = self.board[:action.data] + ['O',
             'X'][self.curPlayer] + self.board[action.data+1:]
-        return TTTState(curPlayer=(self.curPlayer+1) % self.playersNum, playersNum=self.playersNum, board=newBoard, lastMove=action.data)
+        return UTTTState(curPlayer=(self.curPlayer+1) % self.playersNum, playersNum=self.playersNum, board=newBoard, lastMove=action.data)
 
     def box(self, x, y):
         return self.index(x, y) // 9
@@ -67,7 +70,7 @@ class TTTState(State):
     def getAvaibleActions(self):
         if self.last_move == -1:
             for i in range(9*9):
-                yield Action(self.curPlayer, i)
+                yield UTTTAction(self.curPlayer, i)
             return
 
         box_to_play = self.next_box(self.last_move)
@@ -83,19 +86,6 @@ class TTTState(State):
             if self.board[ind] == '.':
                 yield Action(self.curPlayer, ind)
 
-    # def getScoreFor(self, player):
-    #     p = ['O','X'][player]
-    #     sco = 5
-    #     for w in self.box_won:
-    #         if w==p:
-    #             sco += 1
-    #         elif w!='.':
-    #             sco -= 0.5
-    #     return 1/sco
-
-    # def getPriority(self, score, cascadeMem):
-    #     return -cascadeMem*1 + 100
-
     def checkWin(self):
         self.update_box_won()
         game_won = self.check_small_box(self.box_won)
@@ -147,11 +137,15 @@ class TTTState(State):
         return torch.tensor([self.symbToNum(b) for b in s])
 
     @classmethod
-    def getModel(cls, phase='default'):
-        return Model()
+    def getVModel(cls, phase='default'):
+        return TTTV()
+
+    @classmethod
+    def getQModel(cls, phase='default'):
+        return TTTQ()
 
-class Model(nn.Module):
+class TTTV(nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -183,13 +177,6 @@ class Model(nn.Module):
             nn.Linear(self.chansPerSlot*9, self.chansComp),
             nn.ReLU(),
             nn.Linear(self.chansComp, 1),
-            #nn.Linear(9*8, 32),
-            # nn.ReLU(),
-            #nn.Linear(32, 8),
-            # nn.ReLU(),
-            #nn.Linear(16*9, 12),
-            # nn.ReLU(),
-            #nn.Linear(12, 1),
             nn.Sigmoid()
         )
@@ -202,5 +189,54 @@ class Model(nn.Module):
         y = self.out(x)
         return y
 
+class TTTQ(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.chansPerSmol = 24
+        self.chansPerSlot = 8
+        self.chansComp = 8
+
+        self.smol = nn.Sequential(
+            nn.Conv2d(
+                in_channels=2,
+                out_channels=self.chansPerSmol,
+                kernel_size=(3, 3),
+                stride=3,
+                padding=0,
+            ),
+            nn.ReLU()
+        )
+        self.comb = nn.Sequential(
+            nn.Conv1d(
+                in_channels=self.chansPerSmol,
+                out_channels=self.chansPerSlot,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+            nn.ReLU()
+        )
+        self.out = nn.Sequential(
+            nn.Linear(self.chansPerSlot*9*2, self.chansComp),
+            nn.ReLU(),
+            nn.Linear(self.chansComp, 4),
+            nn.ReLU(),
+            nn.Linear(4, 1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        a, b = x
+        a = torch.reshape(a, (1, 9, 9))
+        b = torch.reshape(b, (1, 9, 9))
+        x = torch.stack((a, b))
+        x = self.smol(x)
+        x = torch.reshape(x, (self.chansPerSmol, 9))
+        x = self.comb(x)
+        x = torch.reshape(x, (-1,))
+        y = self.out(x)
+        return y
+
 if __name__=="__main__":
-    main(TTTState)
+    main(UTTTState)
Deleted file (the old standalone Node implementation; a reworked Node class was added to vacuumDecay/base.py above):
@@ -1,204 +0,0 @@
-class Node:
-    def __init__(self, state, universe=None, parent=None, lastAction=None):
-        self.state = state
-        if universe == None:
-            print('[!] No Universe defined. Spawning one...')
-            universe = Universe()
-        self.universe = universe
-        self.parent = parent
-        self.lastAction = lastAction
-
-        self._childs = None
-        self._scores = [None]*self.state.playersNum
-        self._strongs = [None]*self.state.playersNum
-        self._alive = True
-        self._cascadeMemory = 0  # Used for our alternative to alpha-beta pruning
-
-        self.last_updated = time.time()  # New attribute
-
-    def update(self):
-        self.last_updated = time.time()
-        if hasattr(self.universe, 'visualizer'):
-            self.universe.visualizer.send_update()
-
-    def kill(self):
-        self._alive = False
-
-    def revive(self):
-        self._alive = True
-
-    @property
-    def childs(self):
-        if self._childs == None:
-            self._expand()
-        return self._childs
-
-    def _expand(self):
-        self._childs = []
-        actions = self.state.getAvaibleActions()
-        for action in actions:
-            newNode = Node(self.state.mutate(action),
-                           self.universe, self, action)
-            self._childs.append(self.universe.merge(newNode))
-        self.update()
-
-    def getStrongFor(self, player):
-        if self._strongs[player] != None:
-            return self._strongs[player]
-        else:
-            return self.getScoreFor(player)
-
-    def _pullStrong(self):
-        strongs = [None]*self.playersNum
-        for p in range(self.playersNum):
-            cp = self.state.curPlayer
-            if cp == p:
-                best = float('inf')
-                for c in self.childs:
-                    if c.getStrongFor(p) < best:
-                        best = c.getStrongFor(p)
-                strongs[p] = best
-            else:
-                scos = [(c.getStrongFor(p), c.getStrongFor(cp)) for c in self.childs]
-                scos.sort(key=lambda x: x[1])
-                betterHalf = scos[:max(3, int(len(scos)/3))]
-                myScores = [bh[0]**2 for bh in betterHalf]
-                strongs[p] = sqrt(myScores[0]*0.75 + sum(myScores)/(len(myScores)*4))
-        update = False
-        for s in range(self.playersNum):
-            if strongs[s] != self._strongs[s]:
-                update = True
-                break
-        self._strongs = strongs
-        if update:
-            if self.parent != None:
-                cascade = self.parent._pullStrong()
-            else:
-                cascade = 2
-            self._cascadeMemory = self._cascadeMemory/2 + cascade
-            self.update()
-            return cascade + 1
-        self._cascadeMemory /= 2
-        return 0
-
-    def forceStrong(self, depth=3):
-        if depth == 0:
-            self.strongDecay()
-        else:
-            if len(self.childs):
-                for c in self.childs:
-                    c.forceStrong(depth-1)
-            else:
-                self.strongDecay()
-        self.update()
-
-    def decayEvent(self):
-        for c in self.childs:
-            c.strongDecay()
-        self.update()
-
-    def strongDecay(self):
-        if self._strongs == [None]*self.playersNum:
-            if not self.scoresAvaible():
-                self._calcScores()
-            self._strongs = self._scores
-            if self.parent:
-                return self.parent._pullStrong()
-            return 1
-        return None
-
-    def getSelfScore(self):
-        return self.getScoreFor(self.curPlayer)
-
-    def getScoreFor(self, player):
-        if self._scores[player] == None:
-            self._calcScore(player)
-            self.update()
-        return self._scores[player]
-
-    def scoreAvaible(self, player):
-        return self._scores[player] != None
-
-    def scoresAvaible(self):
-        for p in self._scores:
-            if p == None:
-                return False
-        return True
-
-    def strongScoresAvaible(self):
-        for p in self._strongs:
-            if p == None:
-                return False
-        return True
-
-    def askUserForAction(self):
-        return self.state.askUserForAction(self.avaibleActions)
-
-    def _calcScores(self):
-        for p in range(self.state.playersNum):
-            self._calcScore(p)
-
-    def _calcScore(self, player):
-        winner = self._getWinner()
-        if winner != None:
-            if winner == player:
-                self._scores[player] = 0.0
-            elif winner == -1:
-                self._scores[player] = 2/3
-            else:
-                self._scores[player] = 1.0
-            self.update()
-            return
-        if self.universe.scoreProvider == 'naive':
-            self._scores[player] = self.state.getScoreFor(player)
-        elif self.universe.scoreProvider == 'neural':
-            self._scores[player] = self.state.getScoreNeural(self.universe.model, player)
-        else:
-            raise Exception('Unknown Score-Provider')
-        self.update()
-
-    def getPriority(self):
-        return self.state.getPriority(self.getSelfScore(), self._cascadeMemory)
-
-    @property
-    def playersNum(self):
-        return self.state.playersNum
-
-    @property
-    def avaibleActions(self):
-        r = []
-        for c in self.childs:
-            r.append(c.lastAction)
-        return r
-
-    @property
-    def curPlayer(self):
-        return self.state.curPlayer
-
-    def _getWinner(self):
-        return self.state.checkWin()
-
-    def getWinner(self):
-        if len(self.childs) == 0:
-            return -1
-        return self._getWinner()
-
-    def _activateEdge(self, dist=0):
-        if not self.strongScoresAvaible():
-            self.universe.newOpen(self)
-        else:
-            for c in self.childs:
-                if c._cascadeMemory > 0.001*(dist-2) or random.random() < 0.01:
-                    c._activateEdge(dist=dist+1)
-            self.update()
-
-    def __str__(self):
-        s = []
-        if self.lastAction == None:
-            s.append("[ {ROOT} ]")
-        else:
-            s.append("[ -> "+str(self.lastAction)+" ]")
-        s.append("[ turn: "+str(self.state.curPlayer)+" ]")
-        s.append(str(self.state))
-        s.append("[ score: "+str(self.getScoreFor(0))+" ]")
-        return '\n'.join(s)
vacuumDecay/run.py:
@@ -23,25 +23,25 @@ def aiVsAiLoop(StateClass, start_visualizer=False):
     trainer = Trainer(init, start_visualizer=start_visualizer)
     trainer.train()
 
-def humanVsNaive(StateClass, start_visualizer=False):
+def humanVsNaive(StateClass, start_visualizer=False, calcDepth=7):
     run = Runtime(StateClass(), start_visualizer=start_visualizer)
-    run.game()
+    run.game(calcDepth=calcDepth)
 
-def main(StateClass):
+def main(StateClass, **kwargs):
     options = ['Play Against AI',
                'Play Against AI (AI begins)', 'Play Against AI (Fast Play)', 'Playground', 'Let AI train', 'Play against Naive']
     opt = choose('?', options)
     if opt == options[0]:
-        humanVsAi(StateClass)
+        humanVsAi(StateClass, **kwargs)
     elif opt == options[1]:
-        humanVsAi(StateClass, bots=[1, 0])
+        humanVsAi(StateClass, bots=[1, 0], **kwargs)
     elif opt == options[2]:
-        humanVsAi(StateClass, depth=2, noBg=True)
+        humanVsAi(StateClass, depth=2, noBg=True, **kwargs)
     elif opt == options[3]:
-        humanVsAi(StateClass, bots=[None, None])
+        humanVsAi(StateClass, bots=[None, None], **kwargs)
     elif opt == options[4]:
-        aiVsAiLoop(StateClass)
+        aiVsAiLoop(StateClass, **kwargs)
     elif opt == options[5]:
-        humanVsNaive(StateClass)
+        humanVsNaive(StateClass, **kwargs)
     else:
-        aiVsAiLoop(StateClass)
+        aiVsAiLoop(StateClass, **kwargs)
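With main() forwarding **kwargs this way, each game module keeps a one-line entry point, as the game files in this commit do:

    # e.g. at the bottom of the chess module above
    if __name__ == "__main__":
        main(ChessState, start_visualizer=False)  # extra kwargs reach the chosen mode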
vacuumDecay/runtime.py:
@@ -43,14 +43,14 @@ class Runtime():
     def __init__(self, initState, start_visualizer=False):
         universe = QueueingUniverse()
         self.head = Node(initState, universe=universe)
+        self.root = self.head
         _ = self.head.childs
         universe.newOpen(self.head)
-        self.visualizer = None
         if start_visualizer:
             self.startVisualizer()
 
     def startVisualizer(self):
-        self.visualizer = Visualizer(self.head.universe)
+        self.visualizer = Visualizer(self)
         self.visualizer.start()
 
     def spawnWorker(self):
@@ -85,11 +85,11 @@ class Runtime():
         self.head.forceStrong(calcDepth)
         opts = []
         for c in self.head.childs:
-            opts.append((c, c.getStrongFor(self.head.curPlayer)))
-        opts.sort(key=lambda x: x[1])
+            opts.append((c, c.getStrongFor(self.head.curPlayer) + random.random()*0.000000001))
+        opts.sort(key=lambda x: x[1], reverse=True)
         print('[i] Evaluated Options:')
         for o in opts:
-            print('[ ]' + str(o[0].lastAction) + " (Score: "+str(o[1])+")")
+            print('[ ]' + str(o[0].lastAction) + " (Win prob: "+str(int((o[1])*10000)/100)+"%)")
         print('[#] I choose to play: ' + str(opts[0][0].lastAction))
         self.performAction(opts[0][0].lastAction)
     else:
@@ -107,22 +107,23 @@ class Runtime():
         if bg:
             self.killWorker()
 
-    def saveModel(self, model, gen):
-        dat = model.state_dict()
+    def saveModel(self, v_model, q_model, gen):
+        v_state = v_model.state_dict()
+        q_state = q_model.state_dict()
         with open(self.getModelFileName(), 'wb') as f:
-            pickle.dump((gen, dat), f)
+            pickle.dump((gen, v_state, q_state), f)
 
-    def loadModelState(self, model):
+    def loadModelState(self, v_model, q_model):
         with open(self.getModelFileName(), 'rb') as f:
-            gen, dat = pickle.load(f)
-        model.load_state_dict(dat)
-        model.eval()
+            gen, v_state, q_state = pickle.load(f)
+        v_model.load_state_dict(v_state)
+        q_model.load_state_dict(q_state)
         return gen
 
     def loadModel(self):
-        model = self.head.state.getModel()
-        gen = self.loadModelState(model)
-        return model, gen
+        v_model, q_model = self.head.state.getVModel(), self.head.state.getQModel()
+        gen = self.loadModelState(v_model, q_model)
+        return v_model, q_model, gen
 
     def getModelFileName(self):
         return 'brains/uttt.vac'
@@ -136,27 +137,29 @@ class NeuralRuntime(Runtime):
     def __init__(self, initState, **kwargs):
         super().__init__(initState, **kwargs)
 
-        model, gen = self.loadModel()
+        v_model, q_model, gen = self.loadModel()
 
-        self.head.universe.model = model
+        self.head.universe.v_model = v_model
+        self.head.universe.q_model = q_model
         self.head.universe.scoreProvider = 'neural'
 
 class Trainer(Runtime):
     def __init__(self, initState, **kwargs):
         super().__init__(initState, **kwargs)
-        #self.universe = Universe()
         self.universe = self.head.universe
         self.rootNode = self.head
         self.terminal = None
 
-    def buildDatasetFromModel(self, model, depth=4, refining=True, fanOut=[5, 5, 5, 5, 4, 4, 4, 4], uncertainSec=15, exacity=5):
+    def buildDatasetFromModel(self, v_model, q_model, depth=4, refining=True, fanOut=[5, 5, 5, 5, 4, 4, 4, 4], uncertainSec=15, exacity=5):
         print('[*] Building Timeline')
-        term = self.linearPlay(model, calcDepth=depth, exacity=exacity)
+        term = self.linearPlay(v_model, q_model, calcDepth=depth, exacity=exacity)
         if refining:
             print('[*] Refining Timeline (exploring alternative endings)')
             cur = term
             for d in fanOut:
                 cur = cur.parent
+                if cur == None:
+                    break
                 cur.forceStrong(d)
                 print('.', end='', flush=True)
             print('')
@@ -164,9 +167,10 @@ class Trainer(Runtime):
         self.timelineExpandUncertain(term, uncertainSec)
         return term
 
-    def linearPlay(self, model, calcDepth=7, exacity=5, verbose=False, firstNRandom=2):
+    def linearPlay(self, v_model, q_model, calcDepth=7, exacity=5, verbose=False, firstNRandom=2):
         head = self.rootNode
-        self.universe.model = model
+        self.universe.v_model = v_model
+        self.universe.q_model = q_model
         self.spawnWorker()
         while head.getWinner() == None:
             if verbose:
@@ -183,7 +187,7 @@ class Trainer(Runtime):
                 firstNRandom -= 1
                 ind = int(random.random()*len(opts))
             else:
-                opts.sort(key=lambda x: x[1])
+                opts.sort(key=lambda x: x[1], reverse=True)
                 if exacity >= 10:
                     ind = 0
                 else:
@@ -236,31 +240,52 @@ class Trainer(Runtime):
         self.killWorker()
         print('')
 
-    def trainModel(self, model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5, terms=None, batch=16):
+    def trainModel(self, v_model, q_model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5, terms=None, batch=2):
         loss_func = nn.MSELoss()
-        optimizer = optim.Adam(model.parameters(), lr)
+        v_optimizer = optim.Adam(v_model.parameters(), lr)
+        q_optimizer = optim.Adam(q_model.parameters(), lr)
+        print('[*] Conditioning Brain')
         if terms == None:
             terms = []
             for i in range(batch):
                 terms.append(self.buildDatasetFromModel(
-                    model, depth=calcDepth, exacity=exacity))
-        print('[*] Conditioning Brain')
-        for r in range(64):
+                    v_model, q_model, depth=calcDepth, exacity=exacity))
+        for r in range(16):
             loss_sum = 0
             lLoss = 0
             zeroLen = 0
             for i, node in enumerate(self.timelineIter(terms)):
                 for p in range(self.rootNode.playersNum):
                     inp = node.state.getTensor(player=p)
-                    gol = torch.tensor(
+                    v = torch.tensor(
                         [node.getStrongFor(p)], dtype=torch.float)
-                    out = model(inp)
-                    loss = loss_func(out, gol)
-                    optimizer.zero_grad()
-                    loss.backward()
-                    optimizer.step()
-                    loss_sum += loss.item()
-                    if loss.item() == 0.0:
+                    qs = []
+                    q_preds = []
+                    q_loss = torch.Tensor([0])
+                    if node.childs:
+                        for child in node.childs:
+                            sa = child.lastAction.getTensor(node.state, player=p)
+                            q = child.getStrongFor(p)
+                            q_pred = q_model(sa)
+                            qs.append(q)
+                            q_preds.append(q_pred)
+                        qs = torch.Tensor(qs)
+                        q_target = torch.zeros_like(qs).scatter_(0, torch.argmax(qs).unsqueeze(0), 1)
+                        q_cur = torch.concat(q_preds)
+                        q_loss = loss_func(q_cur, q_target)
+                        q_optimizer.zero_grad()
+                        q_loss.backward()
+                        q_optimizer.step()
+
+                    v_pred = v_model(inp)
+                    v_loss = loss_func(v_pred, v)
+                    v_optimizer.zero_grad()
+                    v_loss.backward()
+                    v_optimizer.step()
+
+                    loss = v_loss.item() + q_loss.item()
+                    loss_sum += loss
+                    if v_loss.item() == 0.0:
                         zeroLen += 1
                 if zeroLen == 5:
                     break
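The Q target built above is a one-hot vector over the current node's children, marking the child with the best strong score, while the V target is simply the node's own strong score. A worked example of the scatter_ call that builds the one-hot:

    import torch

    qs = torch.Tensor([0.2, 0.7, 0.4])  # strong scores of three children
    one_hot = torch.zeros_like(qs).scatter_(0, torch.argmax(qs).unsqueeze(0), 1)
    print(one_hot)  # tensor([0., 1., 0.]) - the best child is the classification target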
@@ -270,31 +295,31 @@ class Trainer(Runtime):
             lLoss = loss_sum
         return loss_sum
 
-    def main(self, model=None, gens=1024, startGen=0):
+    def main(self, v_model=None, q_model=None, gens=1024, startGen=0):
         newModel = False
-        if model == None:
+        if v_model == None or q_model == None:
             print('[!] No brain found. Creating new one...')
             newModel = True
-            model = self.rootNode.state.getModel()
+            v_model, q_model = self.rootNode.state.getVModel(), self.rootNode.state.getQModel()
         self.universe.scoreProvider = ['neural', 'naive'][newModel]
-        model.train()
+        v_model.train(), q_model.train()
         for gen in range(startGen, startGen+gens):
             print('[#####] Gen '+str(gen)+' training:')
-            loss = self.trainModel(model, calcDepth=min(
+            loss = self.trainModel(v_model, q_model, calcDepth=min(
                 4, 3+int(gen/16)), exacity=int(gen/3+1), batch=4)
             print('[L] '+str(loss))
             self.universe.scoreProvider = 'neural'
-            self.saveModel(model, gen)
+            self.saveModel(v_model, q_model, gen)
 
     def trainFromTerm(self, term):
-        model, gen = self.loadModel()
+        v_model, q_model, gen = self.loadModel()
         self.universe.scoreProvider = 'neural'
-        self.trainModel(model, calcDepth=4, exacity=10, term=term)
-        self.saveModel(model)
+        self.trainModel(v_model, q_model, calcDepth=4, exacity=10, term=term)
+        self.saveModel(v_model, q_model)
 
     def train(self):
         if os.path.exists(self.getModelFileName()):
-            model, gen = self.loadModel()
-            self.main(model, startGen=gen+1)
+            v_model, q_model, gen = self.loadModel()
+            self.main(v_model, q_model, startGen=gen+1)
         else:
             self.main()
@@ -2,70 +2,124 @@
 <html lang="en">
 <head>
     <meta charset="UTF-8">
-    <title>Game Tree Visualization</title>
+    <title>Interactive Tree Visualization</title>
+    <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"></script>
     <script src="https://d3js.org/d3.v5.min.js"></script>
-    <script src="//cdnjs.cloudflare.com/ajax/libs/socket.io/2.3.0/socket.io.js"></script>
+    <style>
+        .links line {
+            stroke: #999;
+            stroke-opacity: 0.6;
+            stroke-width: 1.5px;
+        }
+
+        .nodes rect {
+            stroke: #fff;
+            stroke-width: 1.5px;
+        }
+
+        text {
+            font: 10px sans-serif;
+            pointer-events: none;
+        }
+    </style>
 </head>
 <body>
     <div id="graph"></div>
     <script>
-        var socket = io.connect('http://' + document.domain + ':' + location.port);
+        var socket = io();
+
+        var margin = {top: 20, right: 120, bottom: 20, left: 120},
+            width = 960 - margin.right - margin.left,
+            height = 800 - margin.top - margin.bottom;
 
         var svg = d3.select("#graph").append("svg")
-            .attr("width", window.innerWidth)
-            .attr("height", window.innerHeight);
+            .attr("width", width + margin.right + margin.left)
+            .attr("height", height + margin.top + margin.bottom)
+            .append("g")
+            .attr("transform", "translate(" + margin.left + "," + margin.top + ")");
 
-        var simulation = d3.forceSimulation()
-            .force("link", d3.forceLink().id(function(d) { return d.id; }).distance(100))
-            .force("charge", d3.forceManyBody().strength(-300))
-            .force("center", d3.forceCenter(window.innerWidth / 2, window.innerHeight / 2));
+        var tree = d3.tree().size([height, width]);
 
-        var link = svg.append("g")
-            .attr("class", "links")
-            .selectAll("line");
-
-        var node = svg.append("g")
-            .attr("class", "nodes")
-            .selectAll("circle");
+        var root;
 
         socket.on('update', function(data) {
-            var nodes = data.nodes;
-            var edges = data.edges;
+            console.log(data);
 
-            link = link.data(edges);
-            link.exit().remove();
-            link = link.enter().append("line").merge(link);
+            var stratify = d3.stratify()
+                .id(function(d) { return d.id; })
+                .parentId(function(d) { return d.parentId; });
+
+            try {
+                root = stratify(data.nodes);
+            } catch (e) {
+                console.error(e);
+                return;
+            }
+
+            tree(root);
+
+            var link = svg.selectAll(".link")
+                .data(root.links(), function(d) { return d.source.id + "-" + d.target.id; });
+
+            link.exit().remove();
 
-            node = node.data(nodes);
-            node.exit().remove();
-            node = node.enter().append("circle")
-                .attr("r", 5)
-                .attr("fill", function(d) {
-                    var age = Date.now() - d.last_updated;
-                    return d3.interpolateCool(Math.min(age / 10000, 1));
-                })
-                .merge(node);
+            link.enter().append("path")
+                .attr("class", "link")
+                .merge(link)
+                .attr("d", d3.linkHorizontal()
+                    .x(function(d) { return d.y; })
+                    .y(function(d) { return d.x; }));
+
+            var node = svg.selectAll(".node")
+                .data(root.descendants(), function(d) { return d.id; });
+
+            node.exit().remove();
 
-            simulation.nodes(nodes)
-                .on("tick", ticked);
-            simulation.force("link")
-                .links(edges);
-            simulation.alpha(1).restart();
-        });
+            var nodeEnter = node.enter().append("g")
+                .attr("class", "node")
+                .attr("transform", function(d) {
+                    return "translate(" + d.y + "," + d.x + ")";
+                });
 
-        function ticked() {
-            link
-                .attr("x1", function(d) { return d.source.x; })
-                .attr("y1", function(d) { return d.source.y; })
-                .attr("x2", function(d) { return d.target.x; })
-                .attr("y2", function(d) { return d.target.y; });
+            nodeEnter.append("rect")
+                .attr("width", 40)
+                .attr("height", 40)
+                .attr("x", -20)
+                .attr("y", -20)
+                .attr("fill", function(d) {
+                    var age = Date.now() - d.data.last_updated;
+                    return d3.interpolateCool(Math.min(age / 10000, 1));
+                });
 
-            node
-                .attr("cx", function(d) { return d.x; })
-                .attr("cy", function(d) { return d.y; });
-        }
+            nodeEnter.append("image")
+                .attr("xlink:href", function(d) { return d.data.image ? 'data:image/jpeg;base64,' + d.data.image : ''; })
+                .attr("x", -20)
+                .attr("y", -20)
+                .attr("width", 40)
+                .attr("height", 40);
+
+            nodeEnter.append("text")
+                .attr("dy", -30)
+                .attr("dx", 0)
+                .text(function(d) { return "Player: " + (d.data.currentPlayer !== undefined ? d.data.currentPlayer : 'N/A'); });
+
+            nodeEnter.append("text")
+                .attr("dy", -15)
+                .attr("dx", 0)
+                .text(function(d) {
+                    if (d.data.winProbs && d.data.winProbs.length >= 2) {
+                        return "Win Probs: P0: " + d.data.winProbs[0].toFixed(2) + ", P1: " + d.data.winProbs[1].toFixed(2);
+                    } else {
+                        return "Win Probs: N/A";
+                    }
+                });
+
+            node = nodeEnter.merge(node);
+
+            node.attr("transform", function(d) {
+                return "translate(" + d.y + "," + d.x + ")";
+            });
+        });
     </script>
 </body>
 </html>
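Note on the template above: `d3.stratify()` rebuilds the hierarchy client-side from a flat node list, which is why the `try`/`catch` is there: `stratify` throws if ids are duplicated or if there is not exactly one root (a node whose `parentId` is null). A sketch of the payload shape the template expects, with field names taken from the diff and invented values:

```python
# Illustrative 'update' payload; field names match the template above,
# values are made up. Exactly one node may have parentId None (the root),
# and ids must be unique, or d3.stratify() will throw.
payload = {
    'nodes': [
        {'id': 1, 'parentId': None, 'image': None, 'currentPlayer': 0,
         'winProbs': [0.52, 0.48], 'last_updated': 1700000000000},
        {'id': 2, 'parentId': 1, 'image': None, 'currentPlayer': 1,
         'winProbs': [0.47, 0.53], 'last_updated': 1700000000000},
    ],
    'edges': [{'source': 1, 'target': 2}],
}
```

The `edges` list is still emitted by the server, but the new template derives its links from `root.links()` after stratifying, so only `nodes` (with `parentId`) drives the drawing.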
@@ -3,14 +3,19 @@ import time
 import networkx as nx
 from flask import Flask, render_template, jsonify
 from flask_socketio import SocketIO, emit
+from io import BytesIO
+import base64
 
 class Visualizer:
-    def __init__(self, universe):
-        self.universe = universe
+    def __init__(self, runtime):
+        self.runtime = runtime
         self.graph = nx.DiGraph()
         self.app = Flask(__name__)
         self.socketio = SocketIO(self.app)
         self.init_flask()
+        self.update_thread = threading.Thread(target=self.update_periodically)
+        self.update_thread.daemon = True
+        self.update_thread.start()
 
     def init_flask(self):
         @self.app.route('/')
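Note on the update thread added to `__init__`: a plain daemon `threading.Thread` works here, but Flask-SocketIO also ships its own helper, `start_background_task`, which picks the spawn/sleep primitives matching the active async mode. A possible alternative (not what this commit does), keeping the same one-second cadence:

```python
# Alternative sketch (not what this commit does): Flask-SocketIO's own
# background-task helper. socketio.sleep() yields correctly under eventlet
# and gevent as well as plain threading.
def start_background_updates(self):
    def loop():
        while True:
            self.send_update()
            self.socketio.sleep(1)
    self.socketio.start_background_task(loop)
```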
@@ -19,36 +24,19 @@ class Visualizer:
 
         @self.app.route('/data')
         def data():
-            nodes_data = []
-            edges_data = []
-            for node in self.universe.iter():
-                nodes_data.append({
-                    'id': id(node),
-                    'image': node.state.getImage().tobytes() if node.state.getImage() else None,
-                    'value': node.getScoreFor(node.state.curPlayer),
-                    'last_updated': node.last_updated
-                })
-                for child in node.childs:
-                    edges_data.append({'source': id(node), 'target': id(child)})
-            return jsonify(nodes=nodes_data, edges=edges_data)
+            return jsonify(self.get_data())
 
         @self.socketio.on('connect')
         def handle_connect():
             print('Client connected')
 
     def send_update(self):
-        nodes_data = []
-        edges_data = []
-        for node in self.universe.iter():
-            nodes_data.append({
-                'id': id(node),
-                'image': node.state.getImage().tobytes() if node.state.getImage() else None,
-                'value': node.getScoreFor(node.state.curPlayer),
-                'last_updated': node.last_updated
-            })
-            for child in node.childs:
-                edges_data.append({'source': id(node), 'target': id(child)})
-        self.socketio.emit('update', {'nodes': nodes_data, 'edges': edges_data})
+        self.socketio.emit('update', self.get_data())
+
+    def update_periodically(self):
+        while True:
+            self.send_update()
+            time.sleep(1)
 
     def run(self):
         self.socketio.run(self.app, debug=True, use_reloader=False)
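With this refactor, the `/data` route and the socket `update` event serve the same `get_data()` payload, so the endpoint is easy to probe by hand. A quick check, assuming Flask's default development port 5000 and the third-party `requests` client (not a project dependency):

```python
import requests  # used here only for illustration

# Fetch the same payload the socket pushes every second.
resp = requests.get('http://127.0.0.1:5000/data')
data = resp.json()
print(len(data['nodes']), 'nodes,', len(data['edges']), 'edges')
```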
@@ -56,3 +44,33 @@ class Visualizer:
     def start(self):
         self.thread = threading.Thread(target=self.run)
         self.thread.start()
+
+    def get_data(self):
+        nodes_data = []
+        edges_data = []
+
+        def add_node_data(node, depth=0):
+            img = None
+            if node.state.getImage():  # depth <= 2:
+                buffered = BytesIO()
+                node.state.getImage().save(buffered, format="JPEG")
+                img = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+            nodes_data.append({
+                'id': id(node),
+                'parentId': id(node.parent) if node.parent else None,
+                'image': img,
+                'currentPlayer': node.state.curPlayer,
+                'winProbs': [node.getStrongFor(i) for i in range(node.state.playersNum)],
+                'last_updated': node.last_updated
+            })
+
+            for child in node.childs:
+                edges_data.append({'source': id(node), 'target': id(child)})
+                add_node_data(child, depth=depth + 1)
+
+        head_node = self.runtime.head
+        if head_node:
+            add_node_data(head_node)
+
+        return {'nodes': nodes_data, 'edges': edges_data}
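Note on `get_data()` above: the image pipeline is PIL image → in-memory JPEG → base64 string, which the template then embeds as a `data:` URI. A self-contained round-trip of that encoding; the 40×40 stand-in image is an assumption in place of `node.state.getImage()`:

```python
from io import BytesIO
import base64

from PIL import Image  # Pillow; getImage() is assumed to return a PIL image

# Encode exactly as get_data() does: PIL image -> JPEG bytes -> base64 text.
img = Image.new('RGB', (40, 40), color=(200, 30, 30))  # stand-in for getImage()
buffered = BytesIO()
img.save(buffered, format="JPEG")
encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")

# The client prefixes this with 'data:image/jpeg;base64,' for its <image> tag.
data_uri = 'data:image/jpeg;base64,' + encoded

# Decode back to verify the round trip.
decoded = Image.open(BytesIO(base64.b64decode(encoded)))
assert decoded.size == (40, 40)
```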