Many squashed commits
This commit is contained in:
parent
f210c4f670
commit
6262aea6f0
25
README.md
25
README.md
@ -2,21 +2,24 @@
|
||||
|
||||
Project vacuumDecay is a framework for building AIs for games.
|
||||
Available architectures are
|
||||
- those used in Deep Blue (mini-max / expecti-max)
|
||||
- advanced expecti-max exploration based on utility heuristics
|
||||
- those used in AlphaGo Zero (knowledge distillation using neural-networks)
|
||||
|
||||
- those used in Deep Blue (mini-max / expecti-max)
|
||||
- advanced expecti-max exploration based on utility heuristics
|
||||
- those used in AlphaGo Zero (knowledge distillation using neural-networks)
|
||||
|
||||
A new AI is created by subclassing the State-class and defining the following functionality (mycelia.py provides a template):
|
||||
- initialization (generating the gameboard or similar)
|
||||
- getting available actions for the current situation (returns an Action-object, which can be subclassed to add additional functionality)
|
||||
- applying an action (the state itself should be immutable, a new state should be returned)
|
||||
- checking for a winning-condition (should return None if game has not yet ended)
|
||||
- (optional) a getter for a string-representation of the current state
|
||||
- (optional) a heuristic for the winning-condition (greatly improves capability)
|
||||
- (optional) a getter for a tensor that describes the current game state (required for knowledge distillation)
|
||||
- (optional) interface to allow a human to select an action
|
||||
|
||||
- initialization (generating the gameboard or similar)
|
||||
- getting available actions for the current situation (returns an Action-object, which can be subclassed to add additional functionality)
|
||||
- applying an action (the state itself should be immutable, a new state should be returned)
|
||||
- checking for a winning-condition (should return None if game has not yet ended)
|
||||
- (optional) a getter for a string-representation of the current state
|
||||
- (optional) a heuristic for the winning-condition (greatly improves capability for expecti-max)
|
||||
- (optional) a getter for a tensor that describes the current game state (required for knowledge distillation)
|
||||
- (optional) interface to allow a human to select an action
|
||||
|
||||
### Current state of the project
|
||||
|
||||
The only thing that currently works is the AI for Ultimate TicTacToe.
|
||||
It uses a trained neural heuristic (neuristic)
|
||||
You can train it or play against it (will also train it) using 'python ultimatetictactoe.py'
|
||||
|
@ -7,6 +7,7 @@ name = "vacuumDecay"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"torch",
|
||||
"numpy",
|
||||
"flask",
|
||||
"flask-socketio",
|
||||
"networkx",
|
||||
|
@ -1,4 +1,4 @@
|
||||
from vacuumDecay.runtime import Runtime, NeuralRuntime, Trainer
|
||||
from vacuumDecay.base import Node, Action, Universe, QueueingUniverse
|
||||
from vacuumDecay.base import Node, State, Action, Universe, QueueingUniverse
|
||||
from vacuumDecay.utils import choose
|
||||
from vacuumDecay.run import main
|
@ -1,19 +1,24 @@
|
||||
import torch
|
||||
import time
|
||||
import random
|
||||
from math import sqrt
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import PriorityQueue, Empty
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vacuumDecay.utils import choose
|
||||
|
||||
class Action():
|
||||
# Should hold the data representing an action
|
||||
# Actions are applied to a State in State.mutate
|
||||
|
||||
def __init__(self, player, data):
|
||||
self.player = player
|
||||
self.data = data
|
||||
|
||||
# ImproveMe
|
||||
def __eq__(self, other):
|
||||
# This should be implemented differently
|
||||
# Two actions of different generations will never be compared
|
||||
@ -21,23 +26,33 @@ class Action():
|
||||
return False
|
||||
return str(self.data) == str(other.data)
|
||||
|
||||
# ImproveMe
|
||||
def __str__(self):
|
||||
# should return visual representation of this action
|
||||
# should start with < and end with >
|
||||
return "<P"+str(self.player)+"-"+str(self.data)+">"
|
||||
|
||||
# ImproveMe
|
||||
def getImage(self, state):
|
||||
# Should return an image representation of this action given the current state
|
||||
# Return None if not implemented
|
||||
return None
|
||||
|
||||
# ImproveMe
|
||||
def getTensor(self, state, player=None):
|
||||
# Should return a complete description of the action (including previous state)
|
||||
# This default will work, but may be suboptimal...
|
||||
return (state.getTensor(), state.mutate(self).getTensor())
|
||||
|
||||
class State(ABC):
|
||||
# Hold a representation of the current game-state
|
||||
# Allows retrieving available actions (getAvaibleActions) and applying them (mutate)
|
||||
# Mutations return a new State and should not have any effect on the current State
|
||||
# Allows checking itself for a win (checkWin) or scoring itself based on a simple heuristic (getScore)
|
||||
# The calculated score should be 0 when won; higher when in a worse state; highest for losing
|
||||
# getPriority is used for prioritising certain Nodes / States when expanding / walking the tree
|
||||
# getPriority is used for prioritising certain Nodes / States when expanding / walking the tree (TODO: Remove)
|
||||
|
||||
# Abstract methods need to be overridden, improveMe methods can be overridden
|
||||
|
||||
def __init__(self, curPlayer=0, generation=0, playersNum=2):
|
||||
self.curPlayer = curPlayer
|
||||
@ -81,10 +96,10 @@ class State(ABC):
|
||||
if w == None:
|
||||
return 0.5
|
||||
if w == player:
|
||||
return 0
|
||||
return 1
|
||||
if w == -1:
|
||||
return 0.9
|
||||
return 1
|
||||
return 0.1
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def __str__(self):
|
||||
@ -92,23 +107,40 @@ class State(ABC):
|
||||
return "[#]"
|
||||
|
||||
@abstractmethod
|
||||
def getTensor(self, player=None, phase='default'):
|
||||
def getTensor(self, player=None):
|
||||
if player == None:
|
||||
player = self.curPlayer
|
||||
return torch.tensor([0])
|
||||
|
||||
@classmethod
|
||||
def getModel(cls, phase='default'):
|
||||
def getVModel(cls):
|
||||
# input will be output from state.getTensor
|
||||
pass
|
||||
|
||||
def getScoreNeural(self, model, player=None, phase='default'):
|
||||
return model(self.getTensor(player=player, phase=phase)).item()
|
||||
#improveMe
|
||||
def getQModel(cls):
|
||||
# input will be output from action.getTensor
|
||||
return DefaultQ(cls.getVModel())
|
||||
|
||||
def getScoreNeural(self, model, player=None):
|
||||
return model(self.getTensor(player=player)).item()
|
||||
|
||||
# improveMe
|
||||
def getImage(self):
|
||||
# Should return an image representation of this state
|
||||
# Return None if not implemented
|
||||
return None
|
||||
|
||||
class DefaultQ(nn.Module):
    """Fallback Q-model built on top of a state-value model.

    The input is a pair ``(s, s_prime)``; the output is the sigmoid of the
    value gain ``V(s_prime) - V(s)``.
    """

    def __init__(self, vModel):
        """Store ``vModel``, the module used to evaluate both states."""
        super().__init__()
        self.V = vModel

    def forward(self, inp):
        """Return ``sigmoid(V(s_prime) - V(s))`` for ``inp == (s, s_prime)``."""
        s, s_prime = inp
        v, v_prime = self.V(s), self.V(s_prime)
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return torch.sigmoid(v_prime - v)
|
||||
|
||||
class Universe():
|
||||
def __init__(self):
|
||||
self.scoreProvider = 'naive'
|
||||
@ -160,3 +192,208 @@ class QueueingUniverse(Universe):
|
||||
|
||||
def activateEdge(self, head):
|
||||
head._activateEdge()
|
||||
|
||||
class Node:
    """One position of the search tree: a State plus its tree bookkeeping.

    A node expands its children lazily, caches one score per player
    (``_scores``) as well as the backed-up "strong" scores (``_strongs``),
    and pushes strong-score changes up to its parent.
    """

    def __init__(self, state, universe=None, parent=None, lastAction=None):
        """Wrap ``state``; a new Universe is spawned when none is given."""
        self.state = state
        if universe is None:
            print('[!] No Universe defined. Spawning one...')
            universe = Universe()
        self.universe = universe
        self.parent = parent
        self.lastAction = lastAction

        self._childs = None  # filled on first access of ``childs``
        self._scores = [None]*self.state.playersNum
        self._strongs = [None]*self.state.playersNum
        self._alive = True
        self._cascadeMemory = 0  # Used for our alternative to alpha-beta pruning
        self._winner = -2  # -2 means "checkWin has not been queried yet"

        self.leaf = True
        self.last_updated = time.time()

    def mark_update(self):
        """Record the current time as the moment of the last change."""
        self.last_updated = time.time()

    def kill(self):
        """Flag this node as not alive."""
        self._alive = False

    def revive(self):
        """Flag this node as alive again."""
        self._alive = True

    @property
    def childs(self):
        """Child nodes; the node is expanded on first access."""
        if self._childs is None:
            self._expand()
        return self._childs

    def _expand(self):
        """Create one child per available action and merge it into the universe."""
        self.leaf = False
        self._childs = []
        for action in self.state.getAvaibleActions():
            newNode = Node(self.state.mutate(action),
                           self.universe, self, action)
            self._childs.append(self.universe.merge(newNode))
        self.mark_update()

    def getStrongFor(self, player):
        """Return the strong score for ``player``, falling back to the plain score."""
        strong = self._strongs[player]
        if strong is not None:
            return strong
        return self.getScoreFor(player)

    def _pullStrong(self):
        """Recompute the strong scores from the children and cascade upwards.

        Returns 0 when nothing changed; otherwise the value returned by the
        parent's ``_pullStrong`` (2 at the root) plus one.
        """
        strongs = [None]*self.playersNum
        has_winner = self.getWinner() is not None
        cp = self.state.curPlayer
        for p in range(self.playersNum):
            if has_winner:
                strongs[p] = self.getScoreFor(p)
            elif cp == p:
                # The player to move picks the child that is best for them.
                best = float('-inf')
                for c in self.childs:
                    strong = c.getStrongFor(p)
                    if strong > best:
                        best = strong
                strongs[p] = best
            else:
                # Other players: look at the children the mover likes most and
                # blend the top one (90%) with the average of that selection (10%).
                scos = [(c.getStrongFor(p), c.getStrongFor(cp)) for c in self.childs]
                scos.sort(key=lambda x: x[1], reverse=True)
                betterHalf = [sco for sco, _ in scos[:max(3, int(len(scos)/2))]]
                strongs[p] = betterHalf[0]*0.9 + sum(betterHalf)/(len(betterHalf))*0.1

        update = any(strongs[s] != self._strongs[s] for s in range(self.playersNum))
        self._strongs = strongs
        if update:
            if self.parent is not None:
                cascade = self.parent._pullStrong()
            else:
                cascade = 2
            self._cascadeMemory = self._cascadeMemory/2 + cascade
            self.mark_update()
            return cascade + 1
        self._cascadeMemory /= 2
        return 0

    def forceStrong(self, depth=3):
        """Run ``strongDecay`` on every node ``depth`` plies below (or on earlier dead ends)."""
        if depth == 0 or not len(self.childs):
            self.strongDecay()
            return
        for c in self.childs:
            c.forceStrong(depth-1)

    def decayEvent(self):
        """Run ``strongDecay`` on every direct child."""
        for c in self.childs:
            c.strongDecay()

    def strongDecay(self):
        """Seed the strong scores from the plain scores and notify the parent.

        Does nothing (returns None) when strong scores are already present.
        Otherwise returns the parent's ``_pullStrong`` result, or 1 at the root.
        """
        if any(s is not None for s in self._strongs):
            return None
        if not self.scoresAvaible():
            self._calcScores()
        # Copy, so later writes to _scores cannot silently change _strongs.
        self._strongs = list(self._scores)
        if self.parent:
            return self.parent._pullStrong()
        return 1

    def getSelfScore(self):
        """Return the score for the player whose turn it is."""
        return self.getScoreFor(self.curPlayer)

    def getScoreFor(self, player):
        """Return the cached score for ``player``, computing it on first use."""
        if self._scores[player] is None:
            self._calcScore(player)
        return self._scores[player]

    def scoreAvaible(self, player):
        """Tell whether the score for ``player`` has been computed."""
        return self._scores[player] is not None

    def scoresAvaible(self):
        """Tell whether the scores of all players have been computed."""
        return all(s is not None for s in self._scores)

    def strongScoresAvaible(self):
        """Tell whether the strong scores of all players are set."""
        return all(s is not None for s in self._strongs)

    def askUserForAction(self):
        """Delegate the choice among the available actions to the state."""
        return self.state.askUserForAction(self.avaibleActions)

    def _calcScores(self):
        """Compute the score of every player."""
        for p in range(self.state.playersNum):
            self._calcScore(p)

    def _calcScore(self, player):
        """Compute and cache the score for ``player``.

        Decided games score 1.0 (won), 0.1 (winner == -1) or 0.0 (lost);
        otherwise the universe's score provider is consulted.

        Raises:
            Exception: if the universe names an unknown score provider.
        """
        winner = self._getWinner()
        if winner is not None:
            if winner == player:
                self._scores[player] = 1.0
            elif winner == -1:
                self._scores[player] = 0.1
            else:
                self._scores[player] = 0.0
            return
        if self.universe.scoreProvider == 'naive':
            self._scores[player] = self.state.getScoreFor(player)
        elif self.universe.scoreProvider == 'neural':
            self._scores[player] = self.state.getScoreNeural(self.universe.v_model, player)
        else:
            raise Exception('Unknown Score-Provider')

    def getPriority(self):
        """Return the state's priority for this node's own score and cascade memory."""
        return self.state.getPriority(self.getSelfScore(), self._cascadeMemory)

    @property
    def playersNum(self):
        """Number of players of the wrapped state."""
        return self.state.playersNum

    @property
    def avaibleActions(self):
        """The actions leading to each child, in child order."""
        return [c.lastAction for c in self.childs]

    @property
    def curPlayer(self):
        """Player to move in the wrapped state."""
        return self.state.curPlayer

    def _getWinner(self):
        """Ask the state for a winner without caching."""
        return self.state.checkWin()

    def getWinner(self):
        """Return the winner, -1 when no move is left without a winner, else None.

        The state's own verdict is checked (and cached) first, so a finished
        position without further moves is reported with its actual winner
        instead of being mistaken for a draw.
        """
        if self._winner == -2:
            self._winner = self._getWinner()
        if self._winner is None and len(self.childs) == 0:
            return -1
        return self._winner

    def _activateEdge(self, dist=0):
        """Queue this node for expansion, or descend into promising children.

        Children are followed when their cascade memory exceeds a threshold
        growing with ``dist``, or with a 1% random chance.
        """
        if not self.strongScoresAvaible():
            self.universe.newOpen(self)
        else:
            for c in self.childs:
                if c._cascadeMemory > 0.001*(dist-2) or random.random() < 0.01:
                    c._activateEdge(dist=dist+1)
        self.mark_update()

    def __str__(self):
        """Render the last action, the player to move, the state and player 0's score."""
        s = []
        if self.lastAction is None:
            s.append("[ {ROOT} ]")
        else:
            s.append("[ -> "+str(self.lastAction)+" ]")
        s.append("[ turn: "+str(self.state.curPlayer)+" ]")
        s.append(str(self.state))
        s.append("[ score: "+str(self.getScoreFor(0))+" ]")
        return '\n'.join(s)
|
||||
|
248
vacuumDecay/games/chess_game.py
Normal file
248
vacuumDecay/games/chess_game.py
Normal file
@ -0,0 +1,248 @@
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
import chess
|
||||
import chess.svg
|
||||
import io
|
||||
|
||||
from vacuumDecay import State, Action, Runtime, NeuralRuntime, Trainer, choose, main
|
||||
|
||||
class ChessAction(Action):
    """A chess move made by ``player``; ``data`` is expected to be a ``chess.Move``."""

    # Index of each piece symbol inside the one-hot "moved piece" vector.
    # NOTE(review): ChessState.getTensor orders its planes upper-case first;
    # confirm whether both encodings are meant to differ.
    _PIECE_INDEX = {
        'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
        'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11
    }

    def __init__(self, player, data):
        super().__init__(player, data)

    def __str__(self):
        """Return ``<P{player}-{uci}>``."""
        return "<P"+str(self.player)+"-"+self.data.uci()+">"

    def getImage(self, state=None):
        """Render ``state``'s board with this move highlighted.

        The move is stored in ``self.data`` (there is no ``self.move``), so
        the arrow is built from its from/to squares.
        """
        origin, target = self.data.from_square, self.data.to_square
        # NOTE(review): chess.svg.board is documented to return SVG markup;
        # verify that format='png' and feeding the result to PIL are
        # supported by the installed python-chess / Pillow versions.
        return Image.open(io.BytesIO(chess.svg.board(board=state.board, format='png', squares=[origin, target], arrows=[(origin, target)])))

    def getTensor(self, state, player=None):
        """Describe this action together with the state it is applied to.

        Returns ``(board, additionals, move)`` where ``board`` and
        ``additionals`` come from ``state.getTensor()`` and ``move`` is a
        1-D integer tensor of length 76: an 8x8 mask marking the from- and
        to-square (flattened) followed by a 12-entry one-hot of the moved
        piece. ``player`` is accepted for compatibility with
        ``Action.getTensor`` and is not used.
        """
        board, additionals = state.getTensor()

        # 8x8 mask with the origin and the destination square set to 1
        squares = np.zeros((8, 8), dtype=int)
        from_row, from_col = divmod(self.data.from_square, 8)
        to_row, to_col = divmod(self.data.to_square, 8)
        squares[from_row, from_col] = 1
        squares[to_row, to_col] = 1

        # one-hot over the 12 piece symbols for the piece being moved
        pieceT = np.zeros(12, dtype=int)
        piece = state.board.piece_at(self.data.from_square)
        if piece:
            pieceT[self._PIECE_INDEX[piece.symbol()]] = 1

        # torch.cat needs a sequence of tensors, not numpy arrays passed
        # as separate positional arguments.
        move = th.cat((th.from_numpy(squares.flatten()), th.from_numpy(pieceT)))
        return (board, additionals, move)
|
||||
|
||||
# Material worth per piece type. The king has no entry.
piece_values = {
    chess.PAWN: 1,
    chess.KNIGHT: 3,
    chess.BISHOP: 3,
    chess.ROOK: 5,
    chess.QUEEN: 9,
}
|
||||
|
||||
class ChessState(State):
    """Game state for chess, backed by a ``chess.Board``.

    Player 0 moves first on a fresh board; ``calculate_winning_probability``
    returns ``(white, black)`` and is indexed by player number.
    """

    def __init__(self, curPlayer=0, generation=0, board=None):
        """Create a state; a fresh starting position is used when ``board`` is None."""
        if board is None:
            board = chess.Board()
        self.curPlayer = curPlayer
        self.generation = generation
        self.playersNum = 2
        self.board = board

    def mutate(self, action):
        """Return the state after ``action``; this state is left untouched."""
        newBoard = self.board.copy()
        newBoard.push(action.data)
        return ChessState(curPlayer=(self.curPlayer+1)%2, board=newBoard)

    def calculate_piece_value(self, board, color):
        """Sum the ``piece_values`` of all pieces of ``color`` on ``board``."""
        value = 0
        for square in chess.scan_reversed(board.occupied_co[color]):
            piece = board.piece_at(square)
            if piece is not None:
                value += piece_values.get(piece.piece_type, 0)
        return value

    def calculate_winning_probability(self):
        """Return ``(white, black)`` shares of the total material on the board.

        When no valued piece is left (e.g. only kings) both sides get 0.5
        instead of dividing by zero.
        """
        white_piece_value = self.calculate_piece_value(self.board, chess.WHITE)
        black_piece_value = self.calculate_piece_value(self.board, chess.BLACK)
        total_piece_value = white_piece_value + black_piece_value

        if total_piece_value == 0:
            return 0.5, 0.5

        white_probability = white_piece_value / total_piece_value
        black_probability = black_piece_value / total_piece_value
        return white_probability, black_probability

    def getScoreFor(self, player):
        """Heuristic score for ``player``: 1 won, 0.1 draw, 0 lost, else material share."""
        w = self.checkWin()
        if w is None:
            return self.calculate_winning_probability()[player]
        if w == player:
            return 1
        if w == -1:
            return 0.1
        return 0

    def getAvaibleActions(self):
        """Yield one ChessAction per legal move of the current position."""
        for move in self.board.legal_moves:
            yield ChessAction(self.curPlayer, move)

    def checkWin(self):
        """Return the winning player, -1 for a draw, or None while the game runs.

        On checkmate the player who is *not* to move has won. Stalemate and
        insufficient material are reported as draws.
        """
        if self.board.is_checkmate():
            return (self.curPlayer+1)%2
        if self.board.is_stalemate() or self.board.is_insufficient_material():
            return -1
        return None

    def __str__(self):
        return str(self.board)

    def getTensor(self, player=None):
        """Encode the position as ``(planes, additionals)``.

        ``planes`` is a 12x8x8 integer tensor (one plane per piece symbol).
        ``additionals`` is a 71-entry integer tensor: side to move (1),
        castling rights (4), en-passant square mask (64), half-move clock (1)
        and full-move number (1). ``player`` is accepted because
        ``State.getScoreNeural`` passes it as a keyword; the encoding does
        not depend on it.
        """
        board = self.board
        piece_to_plane = {
            'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
            'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
        }

        tensor = np.zeros((12, 8, 8), dtype=int)
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece:
                plane = piece_to_plane[piece.symbol()]
                row, col = divmod(square, 8)
                tensor[plane, row, col] = 1

        # Side to move
        side_to_move = np.array([1 if board.turn == chess.WHITE else 0])

        # Castling rights
        castling_rights = np.array([
            1 if board.has_kingside_castling_rights(chess.WHITE) else 0,
            1 if board.has_queenside_castling_rights(chess.WHITE) else 0,
            1 if board.has_kingside_castling_rights(chess.BLACK) else 0,
            1 if board.has_queenside_castling_rights(chess.BLACK) else 0
        ])

        # En passant target square
        en_passant = np.zeros((8, 8), dtype=int)
        if board.ep_square:
            row, col = divmod(board.ep_square, 8)
            en_passant[row, col] = 1

        # Half-move clock and full-move number
        half_move_clock = np.array([board.halfmove_clock])
        full_move_number = np.array([board.fullmove_number])

        additionals = np.concatenate([
            side_to_move,
            castling_rights,
            en_passant.flatten(),
            half_move_clock,
            full_move_number
        ])

        return (th.tensor(tensor), th.tensor(additionals))

    @classmethod
    def getVModel(cls):
        """Return a new value network for chess states."""
        return ChessV()

    @classmethod
    def getQModel(cls):
        """Return a new action-value network for chess actions."""
        return ChessQ()

    def getImage(self):
        """Render the current board as an image."""
        # NOTE(review): chess.svg.board is documented to return SVG markup;
        # verify that format='png' and feeding the result to PIL are
        # supported by the installed python-chess / Pillow versions.
        return Image.open(io.BytesIO(chess.svg.board(board=self.board, format='png')))
|
||||
|
||||
class ChessV(nn.Module):
    """Value network for chess positions.

    Consumes ``(board_tensor, additional_info)``: 12 piece planes of 8x8 and
    71 extra features. Batched inputs ``(B, 12, 8, 8)`` / ``(B, 71)`` and
    single unbatched samples are both accepted; the result always has shape
    ``(B, 1)`` (``B == 1`` for an unbatched sample).
    """

    def __init__(self):
        super().__init__()
        # CNN for the board tensor
        self.conv1 = nn.Conv2d(12, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 256)

        # FCNN for the flattened board tensor (12 planes of 8x8)
        self.fc2 = nn.Linear(12 * 8 * 8, 64)

        # FCNN for additional info
        self.fc_additional1 = nn.Linear(71, 64)

        # Combine all outputs
        self.fc_combined1 = nn.Linear(256 + 64 + 64, 128)
        self.fc_combined2 = nn.Linear(128, 1)

    def forward(self, inp):
        """Return the unnormalised value (logit) for each position in ``inp``."""
        board_tensor, additional_info = inp
        # Integer inputs must be converted before they reach the
        # floating-point layers.
        board_tensor = board_tensor.float()
        additional_info = additional_info.float()
        # Promote a single sample to a batch of one.
        if board_tensor.dim() == 3:
            board_tensor = board_tensor.unsqueeze(0)
        if additional_info.dim() == 1:
            additional_info = additional_info.unsqueeze(0)

        # Process the board tensor through the CNN
        x = F.relu(self.conv1(board_tensor))
        x = F.relu(self.conv2(x))
        x = x.reshape(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))

        # Fully connected view of the raw board
        y = F.relu(self.fc2(board_tensor.reshape(board_tensor.size(0), -1)))

        # Process the additional info through the FCNN
        z = F.relu(self.fc_additional1(additional_info))

        # Combine the outputs
        combined = th.cat((x, y, z), dim=1)
        combined = F.relu(self.fc_combined1(combined))
        logit = self.fc_combined2(combined)

        return logit
|
||||
|
||||
class ChessQ(nn.Module):
    """Action-value network for chess moves.

    Consumes ``(board_tensor, additional_info, action)``: 12 piece planes of
    8x8, 71 extra features and a 76-entry move description (64-square mask
    plus 12-entry piece one-hot). Batched inputs and single unbatched
    samples are both accepted; the result always has shape ``(B, 1)``.
    """

    def __init__(self):
        super().__init__()
        # CNN for the board tensor
        self.conv1 = nn.Conv2d(12, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 256)

        # FCNN for the flattened board tensor (12 planes of 8x8)
        self.fc2 = nn.Linear(12 * 8 * 8, 64)

        # FCNN for additional info
        self.fc_additional1 = nn.Linear(71, 64)

        # FCNN for the move description (8x8 square mask + 12 piece one-hot)
        self.fc_action = nn.Linear(8 * 8 + 12, 64)

        # Combine all outputs
        self.fc_combined1 = nn.Linear(256 + 64 + 64 + 64, 128)
        self.fc_combined2 = nn.Linear(128, 1)

    def forward(self, inp):
        """Return the unnormalised Q-value (logit) for each (state, action) in ``inp``."""
        board_tensor, additional_info, action = inp
        # Integer inputs must be converted before they reach the
        # floating-point layers.
        board_tensor = board_tensor.float()
        additional_info = additional_info.float()
        action = action.float()
        # Promote a single sample to a batch of one.
        if board_tensor.dim() == 3:
            board_tensor = board_tensor.unsqueeze(0)
        if additional_info.dim() == 1:
            additional_info = additional_info.unsqueeze(0)
        if action.dim() == 1:
            action = action.unsqueeze(0)

        # Process the board tensor through the CNN
        x = F.relu(self.conv1(board_tensor))
        x = F.relu(self.conv2(x))
        x = x.reshape(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))

        # Fully connected view of the raw board
        y = F.relu(self.fc2(board_tensor.reshape(board_tensor.size(0), -1)))

        # Process the additional info through the FCNN
        z = F.relu(self.fc_additional1(additional_info))

        # Process the move itself; without it every action of a state
        # would receive the same value.
        a = F.relu(self.fc_action(action))

        # Combine the outputs
        combined = th.cat((x, y, z, a), dim=1)
        combined = F.relu(self.fc_combined1(combined))
        logit = self.fc_combined2(combined)

        return logit
|
||||
|
||||
if __name__=="__main__":
|
||||
main(ChessState, start_visualizer=False)
|
@ -25,6 +25,9 @@ class TTTAction(Action):
|
||||
draw.line((x+40, y-40, x-40, y+40), fill='red', width=2)
|
||||
return img
|
||||
|
||||
def getTensor(self, state, player=None):
|
||||
return torch.concat(torch.tensor([self.turn]), torch.tensor(state.board), torch.tensor(state.mutate(self).board))
|
||||
|
||||
class TTTState(State):
|
||||
def __init__(self, curPlayer=0, generation=0, playersNum=2, board=None):
|
||||
if type(board) == type(None):
|
||||
@ -66,17 +69,29 @@ class TTTState(State):
|
||||
s.append(" ".join([str(p) if p!=None else '.' for p in self.board[l*3:][:3]]))
|
||||
return "\n".join(s)
|
||||
|
||||
def getTensor(self):
|
||||
return torch.tensor([self.turn] + self.board)
|
||||
def getTensor(self, player=None):
|
||||
return torch.concat(torch.tensor([self.curPlayer]), torch.tensor(self.board))
|
||||
|
||||
@classmethod
|
||||
def getModel():
|
||||
def getVModel(cls):
|
||||
return torch.nn.Sequential(
|
||||
torch.nn.Linear(10, 10),
|
||||
torch.nn.ReLu(),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(10, 3),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(3,1),
|
||||
torch.nn.Sigmoid(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def getQModel(cls):
|
||||
return torch.nn.Sequential(
|
||||
torch.nn.Linear(20, 12),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(12, 3),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(3,1),
|
||||
torch.nn.Sigmoid(),
|
||||
torch.nn.Linear(3,1)
|
||||
)
|
||||
|
||||
def getImage(self):
|
||||
@ -98,4 +113,4 @@ class TTTState(State):
|
||||
return img
|
||||
|
||||
if __name__=="__main__":
|
||||
main(TTTState)
|
||||
main(TTTState, start_visualizer=False)
|
@ -3,7 +3,7 @@ A lot of this code was stolen from Pulkit Maloo (https://github.com/pulkitmaloo/
|
||||
"""
|
||||
import numpy as np
|
||||
import torch
|
||||
from troch import nn
|
||||
from torch import nn
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from collections import Counter
|
||||
@ -11,8 +11,11 @@ import itertools
|
||||
|
||||
from vacuumDecay import State, Action, Runtime, NeuralRuntime, Trainer, choose, main
|
||||
|
||||
class UTTTAction(Action):
    """Action type used by the Ultimate TicTacToe state.

    Adds nothing of its own: the constructor hands ``player`` and ``data``
    straight to ``Action``.
    """

    def __init__(self, player, data):
        super(UTTTAction, self).__init__(player, data)
|
||||
|
||||
class TTTState(State):
|
||||
class UTTTState(State):
|
||||
def __init__(self, curPlayer=0, generation=0, playersNum=2, board=None, lastMove=-1):
|
||||
if type(board) == type(None):
|
||||
board = "." * 81
|
||||
@ -48,7 +51,7 @@ class TTTState(State):
|
||||
def mutate(self, action):
|
||||
newBoard = self.board[:action.data] + ['O',
|
||||
'X'][self.curPlayer] + self.board[action.data+1:]
|
||||
return TTTState(curPlayer=(self.curPlayer+1) % self.playersNum, playersNum=self.playersNum, board=newBoard, lastMove=action.data)
|
||||
return UTTTState(curPlayer=(self.curPlayer+1) % self.playersNum, playersNum=self.playersNum, board=newBoard, lastMove=action.data)
|
||||
|
||||
def box(self, x, y):
|
||||
return self.index(x, y) // 9
|
||||
@ -67,7 +70,7 @@ class TTTState(State):
|
||||
def getAvaibleActions(self):
|
||||
if self.last_move == -1:
|
||||
for i in range(9*9):
|
||||
yield Action(self.curPlayer, i)
|
||||
yield UTTTAction(self.curPlayer, i)
|
||||
return
|
||||
|
||||
box_to_play = self.next_box(self.last_move)
|
||||
@ -83,19 +86,6 @@ class TTTState(State):
|
||||
if self.board[ind] == '.':
|
||||
yield Action(self.curPlayer, ind)
|
||||
|
||||
# def getScoreFor(self, player):
|
||||
# p = ['O','X'][player]
|
||||
# sco = 5
|
||||
# for w in self.box_won:
|
||||
# if w==p:
|
||||
# sco += 1
|
||||
# elif w!='.':
|
||||
# sco -= 0.5
|
||||
# return 1/sco
|
||||
|
||||
# def getPriority(self, score, cascadeMem):
|
||||
# return -cascadeMem*1 + 100
|
||||
|
||||
def checkWin(self):
|
||||
self.update_box_won()
|
||||
game_won = self.check_small_box(self.box_won)
|
||||
@ -147,11 +137,15 @@ class TTTState(State):
|
||||
return torch.tensor([self.symbToNum(b) for b in s])
|
||||
|
||||
@classmethod
|
||||
def getModel(cls, phase='default'):
|
||||
return Model()
|
||||
def getVModel(cls, phase='default'):
|
||||
return TTTV()
|
||||
|
||||
@classmethod
|
||||
def getQModel(cls, phase='default'):
|
||||
return TTTQ()
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
class TTTV(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@ -183,13 +177,6 @@ class Model(nn.Module):
|
||||
nn.Linear(self.chansPerSlot*9, self.chansComp),
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.chansComp, 1),
|
||||
#nn.Linear(9*8, 32),
|
||||
# nn.ReLU(),
|
||||
#nn.Linear(32, 8),
|
||||
# nn.ReLU(),
|
||||
#nn.Linear(16*9, 12),
|
||||
# nn.ReLU(),
|
||||
#nn.Linear(12, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
@ -202,5 +189,54 @@ class Model(nn.Module):
|
||||
y = self.out(x)
|
||||
return y
|
||||
|
||||
class TTTQ(nn.Module):
    """Action-value network for Ultimate TicTacToe.

    ``forward`` takes a pair of 81-element board encodings. The two boards
    are fed as the two input channels of a convolution that summarises each
    3x3 box, a 1x1 convolution compresses the per-box features, and a small
    MLP with a sigmoid produces a single value.
    """

    def __init__(self):
        super().__init__()

        self.chansPerSmol = 24
        self.chansPerSlot = 8
        self.chansComp = 8

        # One 3x3 kernel position per small box (stride 3), two input
        # channels: one per board of the pair.
        self.smol = nn.Sequential(
            nn.Conv2d(
                in_channels=2,
                out_channels=self.chansPerSmol,
                kernel_size=(3, 3),
                stride=3,
                padding=0,
            ),
            nn.ReLU()
        )
        # Per-box channel compression
        self.comb = nn.Sequential(
            nn.Conv1d(
                in_channels=self.chansPerSmol,
                out_channels=self.chansPerSlot,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ReLU()
        )
        # chansPerSlot features for each of the 9 boxes reach the head.
        self.out = nn.Sequential(
            nn.Linear(self.chansPerSlot*9, self.chansComp),
            nn.ReLU(),
            nn.Linear(self.chansComp, 4),
            nn.ReLU(),
            nn.Linear(4, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a tensor of shape ``(1,)`` with a value in [0, 1] for the pair ``x``."""
        a, b = x
        a = torch.reshape(a, (1, 9, 9))
        b = torch.reshape(b, (1, 9, 9))
        # Join the boards along the channel axis and add a batch axis:
        # (1, 2, 9, 9), matching the 2-channel convolution.
        x = torch.cat((a, b)).float().unsqueeze(0)
        x = self.smol(x)                                   # (1, chansPerSmol, 3, 3)
        x = torch.reshape(x, (1, self.chansPerSmol, 9))
        x = self.comb(x)                                   # (1, chansPerSlot, 9)
        x = torch.reshape(x, (-1,))
        y = self.out(x)
        return y
|
||||
|
||||
if __name__=="__main__":
|
||||
main(TTTState)
|
||||
main(UTTTState)
|
@ -1,204 +0,0 @@
|
||||
class Node:
|
||||
def __init__(self, state, universe=None, parent=None, lastAction=None):
|
||||
self.state = state
|
||||
if universe == None:
|
||||
print('[!] No Universe defined. Spawning one...')
|
||||
universe = Universe()
|
||||
self.universe = universe
|
||||
self.parent = parent
|
||||
self.lastAction = lastAction
|
||||
|
||||
self._childs = None
|
||||
self._scores = [None]*self.state.playersNum
|
||||
self._strongs = [None]*self.state.playersNum
|
||||
self._alive = True
|
||||
self._cascadeMemory = 0 # Used for our alternative to alpha-beta pruning
|
||||
|
||||
self.last_updated = time.time() # New attribute
|
||||
|
||||
def update(self):
|
||||
self.last_updated = time.time()
|
||||
if hasattr(self.universe, 'visualizer'):
|
||||
self.universe.visualizer.send_update()
|
||||
|
||||
def kill(self):
|
||||
self._alive = False
|
||||
|
||||
def revive(self):
|
||||
self._alive = True
|
||||
|
||||
@property
|
||||
def childs(self):
|
||||
if self._childs == None:
|
||||
self._expand()
|
||||
return self._childs
|
||||
|
||||
def _expand(self):
|
||||
self._childs = []
|
||||
actions = self.state.getAvaibleActions()
|
||||
for action in actions:
|
||||
newNode = Node(self.state.mutate(action),
|
||||
self.universe, self, action)
|
||||
self._childs.append(self.universe.merge(newNode))
|
||||
self.update()
|
||||
|
||||
def getStrongFor(self, player):
|
||||
if self._strongs[player] != None:
|
||||
return self._strongs[player]
|
||||
else:
|
||||
return self.getScoreFor(player)
|
||||
|
||||
def _pullStrong(self):
|
||||
strongs = [None]*self.playersNum
|
||||
for p in range(self.playersNum):
|
||||
cp = self.state.curPlayer
|
||||
if cp == p:
|
||||
best = float('inf')
|
||||
for c in self.childs:
|
||||
if c.getStrongFor(p) < best:
|
||||
best = c.getStrongFor(p)
|
||||
strongs[p] = best
|
||||
else:
|
||||
scos = [(c.getStrongFor(p), c.getStrongFor(cp)) for c in self.childs]
|
||||
scos.sort(key=lambda x: x[1])
|
||||
betterHalf = scos[:max(3, int(len(scos)/3))]
|
||||
myScores = [bh[0]**2 for bh in betterHalf]
|
||||
strongs[p] = sqrt(myScores[0]*0.75 + sum(myScores)/(len(myScores)*4))
|
||||
update = False
|
||||
for s in range(self.playersNum):
|
||||
if strongs[s] != self._strongs[s]:
|
||||
update = True
|
||||
break
|
||||
self._strongs = strongs
|
||||
if update:
|
||||
if self.parent != None:
|
||||
cascade = self.parent._pullStrong()
|
||||
else:
|
||||
cascade = 2
|
||||
self._cascadeMemory = self._cascadeMemory/2 + cascade
|
||||
self.update()
|
||||
return cascade + 1
|
||||
self._cascadeMemory /= 2
|
||||
return 0
|
||||
|
||||
def forceStrong(self, depth=3):
|
||||
if depth == 0:
|
||||
self.strongDecay()
|
||||
else:
|
||||
if len(self.childs):
|
||||
for c in self.childs:
|
||||
c.forceStrong(depth-1)
|
||||
else:
|
||||
self.strongDecay()
|
||||
self.update()
|
||||
|
||||
def decayEvent(self):
|
||||
for c in self.childs:
|
||||
c.strongDecay()
|
||||
self.update()
|
||||
|
||||
def strongDecay(self):
|
||||
if self._strongs == [None]*self.playersNum:
|
||||
if not self.scoresAvaible():
|
||||
self._calcScores()
|
||||
self._strongs = self._scores
|
||||
if self.parent:
|
||||
return self.parent._pullStrong()
|
||||
return 1
|
||||
return None
|
||||
|
||||
def getSelfScore(self):
|
||||
return self.getScoreFor(self.curPlayer)
|
||||
|
||||
def getScoreFor(self, player):
|
||||
if self._scores[player] == None:
|
||||
self._calcScore(player)
|
||||
self.update()
|
||||
return self._scores[player]
|
||||
|
||||
def scoreAvaible(self, player):
|
||||
return self._scores[player] != None
|
||||
|
||||
def scoresAvaible(self):
|
||||
for p in self._scores:
|
||||
if p == None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def strongScoresAvaible(self):
|
||||
for p in self._strongs:
|
||||
if p == None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def askUserForAction(self):
    """Delegate to ``state.askUserForAction`` with this node's available actions."""
    choices = self.avaibleActions
    return self.state.askUserForAction(choices)
|
||||
|
||||
def _calcScores(self):
    """Run ``_calcScore`` once for every player index of the state."""
    player_count = self.state.playersNum
    for player_idx in range(player_count):
        self._calcScore(player_idx)
|
||||
|
||||
def _calcScore(self, player):
    """Compute and store ``self._scores[player]``.

    If the state reports a winner, the score is fixed: ``0.0`` when
    ``player`` is the winner, ``2/3`` when the winner is ``-1``, and ``1.0``
    otherwise. With no winner, the score comes from the universe's score
    provider: ``state.getScoreFor`` for ``'naive'`` or
    ``state.getScoreNeural`` for ``'neural'``. ``update()`` is called after
    the score is stored.

    Args:
        player: Index of the player to score.

    Raises:
        Exception: If ``universe.scoreProvider`` is neither ``'naive'`` nor
            ``'neural'``.
    """
    # NOTE(review): the winner gets the lowest value here, while Runtime
    # sorts options descending and prints them as "Win prob"; confirm the
    # intended orientation.
    winner = self._getWinner()
    if winner is not None:
        if winner == player:
            self._scores[player] = 0.0
        elif winner == -1:
            self._scores[player] = 2/3
        else:
            self._scores[player] = 1.0
        self.update()
        return

    provider = self.universe.scoreProvider
    if provider == 'naive':
        self._scores[player] = self.state.getScoreFor(player)
    elif provider == 'neural':
        # The runtimes publish the model as ``universe.v_model``; fall back
        # to the legacy ``universe.model`` attribute when it is absent.
        model = getattr(self.universe, 'v_model', None)
        if model is None:
            model = getattr(self.universe, 'model', None)
        self._scores[player] = self.state.getScoreNeural(model, player)
    else:
        raise Exception('Unknown Score-Provider')
    self.update()
|
||||
|
||||
def getPriority(self):
    """Return ``state.getPriority`` of this node's self-score and cascade memory."""
    own_score = self.getSelfScore()
    return self.state.getPriority(own_score, self._cascadeMemory)
|
||||
|
||||
@property
def playersNum(self):
    """Number of players, read from the wrapped state."""
    state = self.state
    return state.playersNum
|
||||
|
||||
@property
def avaibleActions(self):
    """List of the ``lastAction`` of every child, in child order."""
    return [child.lastAction for child in self.childs]
|
||||
|
||||
@property
def curPlayer(self):
    """Current player, read from the wrapped state."""
    state = self.state
    return state.curPlayer
|
||||
|
||||
def _getWinner(self):
    """Return the result of ``state.checkWin()`` unchanged."""
    outcome = self.state.checkWin()
    return outcome
|
||||
|
||||
def getWinner(self):
    """Return ``-1`` for a node without children, else ``_getWinner()``."""
    has_children = len(self.childs) != 0
    return self._getWinner() if has_children else -1
|
||||
|
||||
def _activateEdge(self, dist=0):
    """Re-open unevaluated nodes at the frontier of the subtree.

    A node whose strong scores are incomplete is handed to
    ``universe.newOpen``. Otherwise the call descends into each child whose
    ``_cascadeMemory`` exceeds ``0.001 * (dist - 2)``, or with a 1 % random
    chance regardless. ``update()`` is called afterwards.

    Args:
        dist: Distance from the node the activation started at.
    """
    if self.strongScoresAvaible():
        threshold = 0.001*(dist-2)
        for child in self.childs:
            # Order matters: the RNG is consulted only when the memory test fails.
            if child._cascadeMemory > threshold or random.random() < 0.01:
                child._activateEdge(dist=dist+1)
    else:
        self.universe.newOpen(self)
    self.update()
|
||||
|
||||
def __str__(self):
    """Render the node as text: action header, turn, state, and player 0's score."""
    if self.lastAction == None:
        header = "[ {ROOT} ]"
    else:
        header = f"[ -> {self.lastAction!s} ]"
    parts = [
        header,
        f"[ turn: {self.state.curPlayer!s} ]",
        str(self.state),
        f"[ score: {self.getScoreFor(0)!s} ]",
    ]
    return '\n'.join(parts)
|
@ -23,25 +23,25 @@ def aiVsAiLoop(StateClass, start_visualizer=False):
|
||||
trainer = Trainer(init, start_visualizer=start_visualizer)
|
||||
trainer.train()
|
||||
|
||||
def humanVsNaive(StateClass, start_visualizer=False, calcDepth=7):
    """Play one game on a plain ``Runtime`` built from a fresh ``StateClass()``.

    Args:
        StateClass: State class; instantiated without arguments.
        start_visualizer: Forwarded to ``Runtime``.
        calcDepth: Forwarded to ``Runtime.game``.
    """
    run = Runtime(StateClass(), start_visualizer=start_visualizer)
    run.game(calcDepth=calcDepth)
|
||||
|
||||
def main(StateClass, **kwargs):
    """Ask the user for a play mode and launch it.

    Args:
        StateClass: State class handed to the selected entry point.
        **kwargs: Forwarded unchanged to the selected entry point. A key
            that collides with an argument fixed by a mode (``bots``,
            ``depth``, ``noBg``) raises ``TypeError``.
    """
    options = ['Play Against AI',
               'Play Against AI (AI begins)', 'Play Against AI (Fast Play)', 'Playground', 'Let AI train', 'Play against Naive']
    opt = choose('?', options)
    if opt == options[0]:
        humanVsAi(StateClass, **kwargs)
    elif opt == options[1]:
        humanVsAi(StateClass, bots=[1, 0], **kwargs)
    elif opt == options[2]:
        humanVsAi(StateClass, depth=2, noBg=True, **kwargs)
    elif opt == options[3]:
        humanVsAi(StateClass, bots=[None, None], **kwargs)
    elif opt == options[4]:
        aiVsAiLoop(StateClass, **kwargs)
    elif opt == options[5]:
        humanVsNaive(StateClass, **kwargs)
    else:
        # Fallback for any selection not listed above.
        aiVsAiLoop(StateClass, **kwargs)
|
||||
|
@ -43,14 +43,14 @@ class Runtime():
|
||||
def __init__(self, initState, start_visualizer=False):
    """Create the root node in a new ``QueueingUniverse`` and queue it.

    Args:
        initState: State wrapped by the root node.
        start_visualizer: When true, ``startVisualizer()`` is called.
    """
    universe = QueueingUniverse()
    root = Node(initState, universe=universe)
    self.head = root
    self.root = root
    # Value discarded on purpose; presumably forces child generation (TODO confirm).
    _ = root.childs
    universe.newOpen(root)
    self.visualizer = None
    if start_visualizer:
        self.startVisualizer()
|
||||
|
||||
def startVisualizer(self):
    """Create a ``Visualizer`` bound to this runtime and start it."""
    self.visualizer = Visualizer(self)
    self.visualizer.start()
|
||||
|
||||
def spawnWorker(self):
|
||||
@ -85,11 +85,11 @@ class Runtime():
|
||||
self.head.forceStrong(calcDepth)
|
||||
opts = []
|
||||
for c in self.head.childs:
|
||||
opts.append((c, c.getStrongFor(self.head.curPlayer)))
|
||||
opts.sort(key=lambda x: x[1])
|
||||
opts.append((c, c.getStrongFor(self.head.curPlayer) + random.random()*0.000000001))
|
||||
opts.sort(key=lambda x: x[1], reverse=True)
|
||||
print('[i] Evaluated Options:')
|
||||
for o in opts:
|
||||
print('[ ]' + str(o[0].lastAction) + " (Score: "+str(o[1])+")")
|
||||
print('[ ]' + str(o[0].lastAction) + " (Win prob: "+str(int((o[1])*10000)/100)+"%)")
|
||||
print('[#] I choose to play: ' + str(opts[0][0].lastAction))
|
||||
self.performAction(opts[0][0].lastAction)
|
||||
else:
|
||||
@ -107,22 +107,23 @@ class Runtime():
|
||||
if bg:
|
||||
self.killWorker()
|
||||
|
||||
def saveModel(self, v_model, q_model, gen):
    """Pickle ``(gen, v_state_dict, q_state_dict)`` to the model file.

    Args:
        v_model: Object exposing ``state_dict()``.
        q_model: Object exposing ``state_dict()``.
        gen: Generation counter stored with the weights.
    """
    v_state = v_model.state_dict()
    q_state = q_model.state_dict()
    with open(self.getModelFileName(), 'wb') as f:
        pickle.dump((gen, v_state, q_state), f)
|
||||
|
||||
def loadModelState(self, v_model, q_model):
    """Load both state dicts from the model file into the given models.

    Args:
        v_model: Object exposing ``load_state_dict``.
        q_model: Object exposing ``load_state_dict``.

    Returns:
        The generation counter stored in the file.
    """
    # NOTE(security): pickle executes arbitrary code on load; only open
    # model files from a trusted source.
    with open(self.getModelFileName(), 'rb') as f:
        gen, v_state, q_state = pickle.load(f)
    v_model.load_state_dict(v_state)
    q_model.load_state_dict(q_state)
    return gen
|
||||
|
||||
def loadModel(self):
    """Build fresh models from the head state and load their saved weights.

    Returns:
        Tuple ``(v_model, q_model, gen)``.
    """
    v_model, q_model = self.head.state.getVModel(), self.head.state.getQModel()
    gen = self.loadModelState(v_model, q_model)
    return v_model, q_model, gen
|
||||
|
||||
def getModelFileName(self):
    """Return the relative path of the pickled model file."""
    model_path = 'brains/uttt.vac'
    return model_path
|
||||
@ -136,27 +137,29 @@ class NeuralRuntime(Runtime):
|
||||
def __init__(self, initState, **kwargs):
    """Set up a ``Runtime`` and switch its universe to neural scoring.

    Args:
        initState: Forwarded to ``Runtime.__init__``.
        **kwargs: Forwarded to ``Runtime.__init__``.
    """
    super().__init__(initState, **kwargs)

    v_model, q_model, gen = self.loadModel()

    universe = self.head.universe
    universe.v_model = v_model
    universe.q_model = q_model
    # Keep the older single-model attribute populated for code that still
    # reads ``universe.model``.
    universe.model = v_model
    universe.scoreProvider = 'neural'
|
||||
|
||||
class Trainer(Runtime):
|
||||
def __init__(self, initState, **kwargs):
    """Set up a ``Runtime`` and keep handles used during training.

    Args:
        initState: Forwarded to ``Runtime.__init__``.
        **kwargs: Forwarded to ``Runtime.__init__``.
    """
    super().__init__(initState, **kwargs)
    start_node = self.head
    self.universe = start_node.universe
    self.rootNode = start_node
    self.terminal = None
|
||||
|
||||
def buildDatasetFromModel(self, model, depth=4, refining=True, fanOut=[5, 5, 5, 5, 4, 4, 4, 4], uncertainSec=15, exacity=5):
|
||||
def buildDatasetFromModel(self, v_model, q_model, depth=4, refining=True, fanOut=[5, 5, 5, 5, 4, 4, 4, 4], uncertainSec=15, exacity=5):
|
||||
print('[*] Building Timeline')
|
||||
term = self.linearPlay(model, calcDepth=depth, exacity=exacity)
|
||||
term = self.linearPlay(v_model, q_model, calcDepth=depth, exacity=exacity)
|
||||
if refining:
|
||||
print('[*] Refining Timeline (exploring alternative endings)')
|
||||
cur = term
|
||||
for d in fanOut:
|
||||
cur = cur.parent
|
||||
if cur == None:
|
||||
break
|
||||
cur.forceStrong(d)
|
||||
print('.', end='', flush=True)
|
||||
print('')
|
||||
@ -164,9 +167,10 @@ class Trainer(Runtime):
|
||||
self.timelineExpandUncertain(term, uncertainSec)
|
||||
return term
|
||||
|
||||
def linearPlay(self, model, calcDepth=7, exacity=5, verbose=False, firstNRandom=2):
|
||||
def linearPlay(self, v_model, q_model, calcDepth=7, exacity=5, verbose=False, firstNRandom=2):
|
||||
head = self.rootNode
|
||||
self.universe.model = model
|
||||
self.universe.v_model = v_model
|
||||
self.universe.q_model = q_model
|
||||
self.spawnWorker()
|
||||
while head.getWinner() == None:
|
||||
if verbose:
|
||||
@ -183,7 +187,7 @@ class Trainer(Runtime):
|
||||
firstNRandom -= 1
|
||||
ind = int(random.random()*len(opts))
|
||||
else:
|
||||
opts.sort(key=lambda x: x[1])
|
||||
opts.sort(key=lambda x: x[1], reverse=True)
|
||||
if exacity >= 10:
|
||||
ind = 0
|
||||
else:
|
||||
@ -236,31 +240,52 @@ class Trainer(Runtime):
|
||||
self.killWorker()
|
||||
print('')
|
||||
|
||||
def trainModel(self, model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5, terms=None, batch=16):
|
||||
def trainModel(self, v_model, q_model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5, terms=None, batch=2):
|
||||
loss_func = nn.MSELoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr)
|
||||
v_optimizer = optim.Adam(v_model.parameters(), lr)
|
||||
q_optimizer = optim.Adam(q_model.parameters(), lr)
|
||||
print('[*] Conditioning Brain')
|
||||
if terms == None:
|
||||
terms = []
|
||||
for i in range(batch):
|
||||
terms.append(self.buildDatasetFromModel(
|
||||
model, depth=calcDepth, exacity=exacity))
|
||||
print('[*] Conditioning Brain')
|
||||
for r in range(64):
|
||||
v_model, q_model, depth=calcDepth, exacity=exacity))
|
||||
for r in range(16):
|
||||
loss_sum = 0
|
||||
lLoss = 0
|
||||
zeroLen = 0
|
||||
for i, node in enumerate(self.timelineIter(terms)):
|
||||
for p in range(self.rootNode.playersNum):
|
||||
inp = node.state.getTensor(player=p)
|
||||
gol = torch.tensor(
|
||||
v = torch.tensor(
|
||||
[node.getStrongFor(p)], dtype=torch.float)
|
||||
out = model(inp)
|
||||
loss = loss_func(out, gol)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
loss_sum += loss.item()
|
||||
if loss.item() == 0.0:
|
||||
qs = []
|
||||
q_preds = []
|
||||
q_loss = torch.Tensor([0])
|
||||
if node.childs:
|
||||
for child in node.childs:
|
||||
sa = child.lastAction.getTensor(node.state, player=p)
|
||||
q = child.getStrongFor(p)
|
||||
q_pred = q_model(sa)
|
||||
qs.append(q)
|
||||
q_preds.append(q_pred)
|
||||
qs = torch.Tensor(qs)
|
||||
q_target = torch.zeros_like(qs).scatter_(0, torch.argmax(qs).unsqueeze(0), 1)
|
||||
q_cur = torch.concat(q_preds)
|
||||
q_loss = loss_func(q_cur, q_target)
|
||||
q_optimizer.zero_grad()
|
||||
q_loss.backward()
|
||||
q_optimizer.step()
|
||||
|
||||
v_pred = v_model(inp)
|
||||
v_loss = loss_func(v_pred, v)
|
||||
v_optimizer.zero_grad()
|
||||
v_loss.backward()
|
||||
v_optimizer.step()
|
||||
|
||||
loss = v_loss.item() + q_loss.item()
|
||||
loss_sum += loss
|
||||
if v_loss.item() == 0.0:
|
||||
zeroLen += 1
|
||||
if zeroLen == 5:
|
||||
break
|
||||
@ -270,31 +295,31 @@ class Trainer(Runtime):
|
||||
lLoss = loss_sum
|
||||
return loss_sum
|
||||
|
||||
def main(self, v_model=None, q_model=None, gens=1024, startGen=0):
    """Train for ``gens`` generations, saving the models after each one.

    If either model is missing, both are created from the root state and the
    first generation uses the ``'naive'`` score provider; every later
    generation uses ``'neural'``.

    Args:
        v_model: Model passed to ``trainModel``/``saveModel``, or ``None``.
        q_model: Model passed to ``trainModel``/``saveModel``, or ``None``.
        gens: Number of generations to run.
        startGen: Number of the first generation.
    """
    newModel = v_model is None or q_model is None
    if newModel:
        print('[!] No brain found. Creating new one...')
        v_model, q_model = self.rootNode.state.getVModel(), self.rootNode.state.getQModel()
    self.universe.scoreProvider = 'naive' if newModel else 'neural'
    v_model.train()
    q_model.train()
    for gen in range(startGen, startGen+gens):
        print('[#####] Gen '+str(gen)+' training:')
        loss = self.trainModel(v_model, q_model, calcDepth=min(
            4, 3+int(gen/16)), exacity=int(gen/3+1), batch=4)
        print('[L] '+str(loss))
        self.universe.scoreProvider = 'neural'
        self.saveModel(v_model, q_model, gen)
|
||||
|
||||
def trainFromTerm(self, term):
    """Train the saved models on a single existing timeline and save them.

    Args:
        term: Terminal node of the timeline to train on.
    """
    v_model, q_model, gen = self.loadModel()
    self.universe.scoreProvider = 'neural'
    # ``trainModel`` takes a list under the keyword ``terms``.
    self.trainModel(v_model, q_model, calcDepth=4, exacity=10, terms=[term])
    # ``saveModel`` requires the generation; reuse the one that was loaded.
    self.saveModel(v_model, q_model, gen)
|
||||
|
||||
def train(self):
    """Resume training from the saved models if the file exists, else start fresh."""
    if os.path.exists(self.getModelFileName()):
        v_model, q_model, gen = self.loadModel()
        # Continue with the generation after the one stored in the file.
        self.main(v_model, q_model, startGen=gen+1)
    else:
        self.main()
|
||||
|
@ -2,70 +2,124 @@
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>Game Tree Visualization</title>
|
||||
<title>Interactive Tree Visualization</title>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"></script>
|
||||
<script src="https://d3js.org/d3.v5.min.js"></script>
|
||||
<script src="//cdnjs.cloudflare.com/ajax/libs/socket.io/2.3.0/socket.io.js"></script>
|
||||
<style>
|
||||
.links line {
|
||||
stroke: #999;
|
||||
stroke-opacity: 0.6;
|
||||
stroke-width: 1.5px;
|
||||
}
|
||||
|
||||
.nodes rect {
|
||||
stroke: #fff;
|
||||
stroke-width: 1.5px;
|
||||
}
|
||||
|
||||
text {
|
||||
font: 10px sans-serif;
|
||||
pointer-events: none;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="graph"></div>
|
||||
<script>
|
||||
var socket = io.connect('http://' + document.domain + ':' + location.port);
|
||||
var socket = io();
|
||||
|
||||
var margin = {top: 20, right: 120, bottom: 20, left: 120},
|
||||
width = 960 - margin.right - margin.left,
|
||||
height = 800 - margin.top - margin.bottom;
|
||||
|
||||
var svg = d3.select("#graph").append("svg")
|
||||
.attr("width", window.innerWidth)
|
||||
.attr("height", window.innerHeight);
|
||||
.attr("width", width + margin.right + margin.left)
|
||||
.attr("height", height + margin.top + margin.bottom)
|
||||
.append("g")
|
||||
.attr("transform", "translate(" + margin.left + "," + margin.top + ")");
|
||||
|
||||
var simulation = d3.forceSimulation()
|
||||
.force("link", d3.forceLink().id(function(d) { return d.id; }).distance(100))
|
||||
.force("charge", d3.forceManyBody().strength(-300))
|
||||
.force("center", d3.forceCenter(window.innerWidth / 2, window.innerHeight / 2));
|
||||
var tree = d3.tree().size([height, width]);
|
||||
|
||||
var link = svg.append("g")
|
||||
.attr("class", "links")
|
||||
.selectAll("line");
|
||||
|
||||
var node = svg.append("g")
|
||||
.attr("class", "nodes")
|
||||
.selectAll("circle");
|
||||
var root;
|
||||
|
||||
socket.on('update', function(data) {
|
||||
var nodes = data.nodes;
|
||||
var edges = data.edges;
|
||||
console.log(data);
|
||||
|
||||
var stratify = d3.stratify()
|
||||
.id(function(d) { return d.id; })
|
||||
.parentId(function(d) { return d.parentId; });
|
||||
|
||||
try {
|
||||
root = stratify(data.nodes);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
return;
|
||||
}
|
||||
|
||||
tree(root);
|
||||
|
||||
var link = svg.selectAll(".link")
|
||||
.data(root.links(), function(d) { return d.source.id + "-" + d.target.id; });
|
||||
|
||||
link = link.data(edges);
|
||||
link.exit().remove();
|
||||
link = link.enter().append("line").merge(link);
|
||||
|
||||
node = node.data(nodes);
|
||||
link.enter().append("path")
|
||||
.attr("class", "link")
|
||||
.merge(link)
|
||||
.attr("d", d3.linkHorizontal()
|
||||
.x(function(d) { return d.y; })
|
||||
.y(function(d) { return d.x; }));
|
||||
|
||||
var node = svg.selectAll(".node")
|
||||
.data(root.descendants(), function(d) { return d.id; });
|
||||
|
||||
node.exit().remove();
|
||||
node = node.enter().append("circle")
|
||||
.attr("r", 5)
|
||||
|
||||
var nodeEnter = node.enter().append("g")
|
||||
.attr("class", "node")
|
||||
.attr("transform", function(d) {
|
||||
return "translate(" + d.y + "," + d.x + ")";
|
||||
});
|
||||
|
||||
nodeEnter.append("rect")
|
||||
.attr("width", 40)
|
||||
.attr("height", 40)
|
||||
.attr("x", -20)
|
||||
.attr("y", -20)
|
||||
.attr("fill", function(d) {
|
||||
var age = Date.now() - d.last_updated;
|
||||
var age = Date.now() - d.data.last_updated;
|
||||
return d3.interpolateCool(Math.min(age / 10000, 1));
|
||||
})
|
||||
.merge(node);
|
||||
});
|
||||
|
||||
simulation.nodes(nodes)
|
||||
.on("tick", ticked);
|
||||
nodeEnter.append("image")
|
||||
.attr("xlink:href", function(d) { return d.data.image ? 'data:image/jpeg;base64,' + d.data.image : ''; })
|
||||
.attr("x", -20)
|
||||
.attr("y", -20)
|
||||
.attr("width", 40)
|
||||
.attr("height", 40);
|
||||
|
||||
simulation.force("link")
|
||||
.links(edges);
|
||||
nodeEnter.append("text")
|
||||
.attr("dy", -30)
|
||||
.attr("dx", 0)
|
||||
.text(function(d) { return "Player: " + (d.data.currentPlayer !== undefined ? d.data.currentPlayer : 'N/A'); });
|
||||
|
||||
simulation.alpha(1).restart();
|
||||
nodeEnter.append("text")
|
||||
.attr("dy", -15)
|
||||
.attr("dx", 0)
|
||||
.text(function(d) {
|
||||
if (d.data.winProbs && d.data.winProbs.length >= 2) {
|
||||
return "Win Probs: P0: " + d.data.winProbs[0].toFixed(2) + ", P1: " + d.data.winProbs[1].toFixed(2);
|
||||
} else {
|
||||
return "Win Probs: N/A";
|
||||
}
|
||||
});
|
||||
|
||||
node = nodeEnter.merge(node);
|
||||
|
||||
node.attr("transform", function(d) {
|
||||
return "translate(" + d.y + "," + d.x + ")";
|
||||
});
|
||||
});
|
||||
|
||||
function ticked() {
|
||||
link
|
||||
.attr("x1", function(d) { return d.source.x; })
|
||||
.attr("y1", function(d) { return d.source.y; })
|
||||
.attr("x2", function(d) { return d.target.x; })
|
||||
.attr("y2", function(d) { return d.target.y; });
|
||||
|
||||
node
|
||||
.attr("cx", function(d) { return d.x; })
|
||||
.attr("cy", function(d) { return d.y; });
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
@ -3,14 +3,19 @@ import time
|
||||
import networkx as nx
|
||||
from flask import Flask, render_template, jsonify
|
||||
from flask_socketio import SocketIO, emit
|
||||
from io import BytesIO
|
||||
import base64
|
||||
|
||||
class Visualizer:
|
||||
def __init__(self, runtime):
    """Create the Flask/SocketIO app and start the periodic update thread.

    Args:
        runtime: Object whose ``head`` attribute is read when the node data
            is collected.
    """
    self.runtime = runtime
    self.graph = nx.DiGraph()
    self.app = Flask(__name__)
    self.socketio = SocketIO(self.app)
    self.init_flask()
    # Daemon thread, so it does not keep the process alive on exit.
    self.update_thread = threading.Thread(target=self.update_periodically)
    self.update_thread.daemon = True
    self.update_thread.start()
|
||||
|
||||
def init_flask(self):
|
||||
@self.app.route('/')
|
||||
@ -19,36 +24,19 @@ class Visualizer:
|
||||
|
||||
@self.app.route('/data')
|
||||
def data():
|
||||
nodes_data = []
|
||||
edges_data = []
|
||||
for node in self.universe.iter():
|
||||
nodes_data.append({
|
||||
'id': id(node),
|
||||
'image': node.state.getImage().tobytes() if node.state.getImage() else None,
|
||||
'value': node.getScoreFor(node.state.curPlayer),
|
||||
'last_updated': node.last_updated
|
||||
})
|
||||
for child in node.childs:
|
||||
edges_data.append({'source': id(node), 'target': id(child)})
|
||||
return jsonify(nodes=nodes_data, edges=edges_data)
|
||||
return jsonify(self.get_data())
|
||||
|
||||
@self.socketio.on('connect')
|
||||
def handle_connect():
|
||||
print('Client connected')
|
||||
|
||||
def send_update(self):
    """Emit the current ``get_data()`` payload as an ``'update'`` socket event."""
    self.socketio.emit('update', self.get_data())
|
||||
|
||||
def update_periodically(self):
    """Call ``send_update()`` forever, sleeping one second after each call."""
    interval_seconds = 1
    while True:
        self.send_update()
        time.sleep(interval_seconds)
|
||||
|
||||
def run(self):
    """Run the SocketIO server for the Flask app (debug on, reloader off)."""
    server = self.socketio
    server.run(self.app, debug=True, use_reloader=False)
|
||||
@ -56,3 +44,33 @@ class Visualizer:
|
||||
def start(self):
    """Launch ``run()`` on a new thread stored in ``self.thread``."""
    worker = threading.Thread(target=self.run)
    self.thread = worker
    worker.start()
|
||||
|
||||
def get_data(self):
    """Collect the subtree below ``runtime.head`` as JSON-serialisable data.

    Each node entry holds ``id``, ``parentId``, ``image`` (base64 JPEG text
    or ``None``), ``currentPlayer``, ``winProbs`` and ``last_updated``. The
    first node emitted always has ``parentId`` ``None``, even when the head
    has a parent, so the payload has exactly one root.

    Returns:
        Dict with the lists ``'nodes'`` and ``'edges'``; both are empty when
        the runtime has no head.
    """
    nodes_data = []
    edges_data = []

    def add_node_data(node, parent_id=None):
        # NOTE(review): iterating ``node.childs`` may expand the tree if that
        # attribute generates children lazily; confirm.
        img = None
        image = node.state.getImage()  # fetched once; may be costly to build
        if image:
            buffered = BytesIO()
            image.save(buffered, format="JPEG")
            img = base64.b64encode(buffered.getvalue()).decode("utf-8")

        nodes_data.append({
            'id': id(node),
            'parentId': parent_id,
            'image': img,
            'currentPlayer': node.state.curPlayer,
            'winProbs': [node.getStrongFor(i) for i in range(node.state.playersNum)],
            'last_updated': node.last_updated
        })

        for child in node.childs:
            edges_data.append({'source': id(node), 'target': id(child)})
            add_node_data(child, parent_id=id(node))

    head_node = self.runtime.head
    if head_node:
        add_node_data(head_node)

    return {'nodes': nodes_data, 'edges': edges_data}
|
||||
|
Loading…
Reference in New Issue
Block a user