vacuumDecay/vacuumDecay.py

632 lines
20 KiB
Python
Raw Normal View History

import os
import io
2022-03-21 14:27:16 +01:00
import time
import random
import threading
import torch
2022-04-15 00:32:48 +02:00
import torch.nn as nn
from torch import optim
from math import sqrt, pow, inf
2022-03-21 14:27:16 +01:00
#from multiprocessing import Event
from abc import ABC, abstractmethod
from threading import Event
from queue import PriorityQueue, Empty
2022-04-14 21:05:45 +02:00
from dataclasses import dataclass, field
from typing import Any
2022-04-15 11:18:34 +02:00
import random
2022-04-16 11:20:25 +02:00
import datetime
import pickle
2022-03-21 14:27:16 +01:00
class Action():
    """A single move: the acting player plus game-specific payload.

    Actions are applied to a State in State.mutate.
    """

    def __init__(self, player, data):
        self.player = player
        self.data = data

    def __eq__(self, other):
        # This should be implemented differently
        # Two actions of different generations will never be compared
        # Only the string form of ``data`` is compared; ``player`` is ignored.
        if type(other) is not type(self):
            return False
        return str(self.data) == str(other.data)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable; hash the same
        # value __eq__ compares so equal actions hash equal.
        return hash(str(self.data))

    def __str__(self):
        # should return visual representation of this action
        # should start with < and end with >
        return "<P"+str(self.player)+"-"+str(self.data)+">"
class State(ABC):
    """Abstract representation of one game position.

    Subclasses list the legal actions (getAvaibleActions) and apply them
    (mutate); mutate returns a new State and must leave this one untouched.
    A state can test itself for a finished game (checkWin) and rate itself
    with a simple heuristic (getScoreFor): 0 when won, larger when worse,
    largest when lost.  getPriority orders nodes in the expansion queue.
    """

    def __init__(self, curPlayer=0, generation=0, playersNum=2):
        self.playersNum = playersNum
        self.generation = generation
        self.curPlayer = curPlayer

    @abstractmethod
    def mutate(self, action):
        # Returns a new state with the supplied action performed;
        # self must not be changed.
        nextPlayer = (self.curPlayer+1) % self.playersNum
        return State(curPlayer=nextPlayer, generation=self.generation+1, playersNum=self.playersNum)

    @abstractmethod
    def getAvaibleActions(self):
        # Should return an array of all possible actions
        return list()

    def askUserForAction(self, actions):
        """Prompt the current player to pick one of ``actions``."""
        prompt = 'What does player '+str(self.curPlayer)+' want to do?'
        return choose(prompt, actions)

    # improveMe
    def getPriority(self, score, cascadeMemory):
        # Used for ordering the priority queue; lower values are worked on
        # first.  Priority should not change for the same root.
        # This default ignores ``score`` and depends only on cascadeMemory.
        return 100 - cascadeMemory

    @abstractmethod
    def checkWin(self):
        # -1   -> draw
        # None -> not ended
        # n    -> player n won
        return None

    # improveMe
    def getScoreFor(self, player):
        # 0 <= score <= 1; close to zero when ``player`` is winning
        winner = self.checkWin()
        if winner is None:
            return 0.5
        if winner == player:
            return 0
        return 0.9 if winner == -1 else 1

    @abstractmethod
    def __str__(self):
        # return visual rep of state
        return "[#]"

    @abstractmethod
    def getTensor(self, player=None, phase='default'):
        # Placeholder encoding; ``player`` defaults to the player to move.
        if player is None:
            player = self.curPlayer
        return torch.tensor([0])

    @classmethod
    def getModel(cls, phase='default'):
        # Placeholder: this base implementation returns None.
        return None

    def getScoreNeural(self, model, player=None, phase='default'):
        """Return ``model`` applied to this state's tensor, as a Python scalar."""
        features = self.getTensor(player=player, phase=phase)
        return model(features).item()
2022-03-21 14:27:16 +01:00
2022-04-14 21:05:45 +02:00
class Universe():
    """Do-nothing node registry used when no queueing is wanted.

    Every hook is a no-op; subclasses override them to schedule nodes.
    ``scoreProvider`` starts out as 'naive'.
    """

    def __init__(self):
        self.scoreProvider = 'naive'

    def newOpen(self, node):
        # Hook: a node became open for expansion. Nothing to do here.
        return None

    def merge(self, node):
        # Hook: a freshly created node joins the universe; hand it back as is.
        return node

    def clearPQ(self):
        # Hook: drop all pending work. Nothing is queued here.
        return None

    def iter(self):
        # No pending nodes to hand out.
        return list()

    def activateEdge(self, head):
        # Hook: the active head moved. Nothing to do here.
        return None
@dataclass(order=True)
class PQItem:
    # Entry stored in the PriorityQueue of QueueingUniverse.
    # Ordering uses ``priority`` only: ``data`` is excluded from comparison
    # (compare=False), so the payload never has to be orderable itself.
    priority: int
    data: Any=field(compare=False)
class QueueingUniverse(Universe):
    """Universe that keeps open nodes in a priority queue."""

    def __init__(self):
        super().__init__()
        self.pq = PriorityQueue()

    def newOpen(self, node):
        # Queue the node under its current priority (lower is served first).
        self.pq.put(PQItem(node.getPriority(), node))

    def merge(self, node):
        # Every merged node is immediately scheduled for work.
        self.newOpen(node)
        return node

    def clearPQ(self):
        # Forget all pending work by swapping in a fresh queue.
        self.pq = PriorityQueue()

    def iter(self):
        # Generator handing out queued nodes without blocking; it finishes
        # as soon as the queue is found empty.
        try:
            while True:
                yield self.pq.get(block=False).data
        except Empty:
            return None

    def activateEdge(self, head):
        head._activateEdge()
2022-03-21 14:27:16 +01:00
class Node():
    """One position in the game tree, wrapping a State.

    Each node caches two per-player value lists, both "lower is better":
    ``_scores`` holds the direct evaluation of this node's own state and
    ``_strongs`` holds values backed up from descendants by ``_pullStrong``.
    Children are created lazily the first time ``childs`` is read.
    """

    def __init__(self, state, universe=None, parent=None, lastAction=None):
        self.state = state
        if universe is None:
            print('[!] No Universe defined. Spawning one...')
            universe = Universe()
        self.universe = universe
        self.parent = parent
        self.lastAction = lastAction
        self._childs = None
        self._scores = [None]*self.state.playersNum
        self._strongs = [None]*self.state.playersNum
        self._alive = True
        self._cascadeMemory = 0  # Used for our alternative to alpha-beta pruning

    def kill(self):
        self._alive = False

    def revive(self):
        self._alive = True

    @property
    def childs(self):
        """Child nodes, expanded on first access."""
        if self._childs is None:
            self._expand()
        return self._childs

    def _expand(self):
        # Create one child per available action and register it with the universe.
        self._childs = []
        for action in self.state.getAvaibleActions():
            newNode = Node(self.state.mutate(action), self.universe, self, action)
            self._childs.append(self.universe.merge(newNode))

    def getStrongFor(self, player):
        """Return the backed-up value for ``player``, or the plain score if none exists yet."""
        if self._strongs[player] is not None:
            return self._strongs[player]
        return self.getScoreFor(player)

    def _pullStrong(self):  # Currently Expecti-Max
        """Recompute ``_strongs`` from the children and propagate changes upwards.

        For the player to move the value is the minimum over all children.
        For every other player the children are ordered by how good they are
        for the mover, the first max(3, n/3) are kept, and their values are
        blended as sqrt(0.75*first^2 + mean(squares)/4).

        Returns 0 when nothing changed, otherwise the parent's result plus one
        (a root node uses 2 as the parent's result).
        """
        strongs = [None]*self.playersNum
        cp = self.state.curPlayer
        for p in range(self.playersNum):
            if cp == p:  # P owns the turn; controls outcome
                best = inf
                for c in self.childs:
                    val = c.getStrongFor(p)
                    if val < best:
                        best = val
                strongs[p] = best
            else:
                scos = [(c.getStrongFor(p), c.getStrongFor(cp)) for c in self.childs]
                scos.sort(key=lambda x: x[1])
                betterHalf = scos[:max(3, int(len(scos)/3))]
                myScores = [bh[0]**2 for bh in betterHalf]
                strongs[p] = sqrt(myScores[0]*0.75 + sum(myScores)/(len(myScores)*4))
        update = any(new != old for new, old in zip(strongs, self._strongs))
        self._strongs = strongs
        if update:
            if self.parent is not None:
                cascade = self.parent._pullStrong()
            else:
                cascade = 2
            self._cascadeMemory = self._cascadeMemory/2 + cascade
            return cascade + 1
        self._cascadeMemory /= 2
        return 0

    def forceStrong(self, depth=3):
        """Expand ``depth`` plies below this node and back up the leaf values."""
        if depth == 0 or not len(self.childs):
            self.strongDecay()
        else:
            for c in self.childs:
                c.forceStrong(depth-1)

    def decayEvent(self):
        for c in self.childs:
            c.strongDecay()

    def strongDecay(self):
        """Seed ``_strongs`` from the plain scores (once) and notify the parent.

        Returns the parent's ``_pullStrong`` result, 1 for a parentless node,
        or None when strong values already existed.
        """
        if self._strongs == [None]*self.playersNum:
            if not self.scoresAvaible():
                self._calcScores()
            # Copy, so the two caches never share one list object.
            self._strongs = list(self._scores)
            if self.parent:
                return self.parent._pullStrong()
            return 1
        return None

    def getSelfScore(self):
        return self.getScoreFor(self.curPlayer)

    def getScoreFor(self, player):
        """Return the (lazily computed) direct score of this node for ``player``."""
        if self._scores[player] is None:
            self._calcScore(player)
        return self._scores[player]

    def scoreAvaible(self, player):
        return self._scores[player] is not None

    def scoresAvaible(self):
        return all(s is not None for s in self._scores)

    def strongScoresAvaible(self):
        return all(s is not None for s in self._strongs)

    def askUserForAction(self):
        return self.state.askUserForAction(self.avaibleActions)

    def _calcScores(self):
        for p in range(self.state.playersNum):
            self._calcScore(p)

    def _calcScore(self, player):
        # Finished games get fixed values: 0.0 win, 2/3 draw, 1.0 loss.
        winner = self._getWinner()
        if winner is not None:
            if winner == player:
                self._scores[player] = 0.0
            elif winner == -1:
                self._scores[player] = 2/3
            else:
                self._scores[player] = 1.0
            return
        provider = self.universe.scoreProvider
        if provider == 'naive':
            self._scores[player] = self.state.getScoreFor(player)
        elif provider == 'neural':
            self._scores[player] = self.state.getScoreNeural(self.universe.model, player)
        else:
            raise Exception('Uknown Score-Provider')

    def getPriority(self):
        return self.state.getPriority(self.getSelfScore(), self._cascadeMemory)

    @property
    def playersNum(self):
        return self.state.playersNum

    @property
    def avaibleActions(self):
        return [c.lastAction for c in self.childs]

    @property
    def curPlayer(self):
        return self.state.curPlayer

    def _getWinner(self):
        return self.state.checkWin()

    def getWinner(self):
        """Return the winner, -1 for a draw, or None while the game is running.

        The state's own verdict is consulted first so that a won position
        without any remaining action is reported as a win (matching
        ``_calcScore``) rather than as a draw.  A position with no winner and
        no available action counts as a draw.
        """
        winner = self._getWinner()
        if winner is not None:
            return winner
        if len(self.childs) == 0:
            return -1
        return None

    def _activateEdge(self, dist=0):
        # Re-queue unevaluated nodes; below evaluated ones, follow children
        # whose cascadeMemory exceeds a distance-dependent threshold, plus a
        # 1% random sample of the rest.
        if not self.strongScoresAvaible():
            self.universe.newOpen(self)
        else:
            for c in self.childs:
                if c._cascadeMemory > 0.001*(dist-2) or random.random() < 0.01:
                    c._activateEdge(dist=dist+1)

    def __str__(self):
        s = []
        if self.lastAction is None:
            s.append("[ {ROOT} ]")
        else:
            s.append("[ -> "+str(self.lastAction)+" ]")
        s.append("[ turn: "+str(self.state.curPlayer)+" ]")
        s.append(str(self.state))
        s.append("[ score: "+str(self.getScoreFor(0))+" ]")
        return '\n'.join(s)
2022-04-13 22:49:38 +02:00
def choose(txt, options):
    """Ask on the console until the user picks one of ``options``.

    An answer is accepted, in this order, as a 1-based index into
    ``options``, as the exact string form of an option, or as a single
    character equal to the first character of an option's string form.

    Returns the chosen element of ``options``.
    """
    while True:
        print('[*] '+txt)
        for num, opt in enumerate(options, start=1):
            print('['+str(num)+'] ' + str(opt))
        inp = input('[> ')
        try:
            n = int(inp)
        except ValueError:
            # Not a number; fall through to matching by name.
            n = None
        if n is not None and 1 <= n <= len(options):
            return options[n-1]
        for opt in options:
            if inp == str(opt):
                return opt
        if len(inp) == 1:
            for opt in options:
                # startswith (instead of [0]) tolerates options whose
                # string form is empty.
                if str(opt).startswith(inp):
                    return opt
        print('[!] Invalid Input.')
2022-04-14 20:24:48 +02:00
class Worker():
    """Background consumer that calls decayEvent() on nodes from a universe."""

    def __init__(self, universe):
        self.universe = universe
        self._alive = True
        # Set by run(); stays None while no thread has been started.
        self.thread = None

    def run(self):
        """Start processing in a separate thread."""
        self.thread = threading.Thread(target=self.runLocal)
        self.thread.start()

    def runLocal(self):
        """Process nodes from ``universe.iter()`` until it ends or the worker is killed."""
        for node in self.universe.iter():
            if not self._alive:
                return
            if node is None:
                # Nothing to do right now: wait, then ask again instead of
                # calling decayEvent() on None.
                time.sleep(1)
                continue
            node.decayEvent()

    def kill(self):
        """Ask the worker to stop and wait up to 15 seconds for its thread."""
        self._alive = False
        if self.thread is not None:
            self.thread.join(15)

    def revive(self):
        self._alive = True
2022-04-13 22:49:38 +02:00
class Runtime():
    """Interactive game loop around a tree of Nodes.

    ``head`` is the node of the current position; it moves down the tree as
    actions are performed.
    """

    def __init__(self, initState):
        universe = QueueingUniverse()
        self.head = Node(initState, universe=universe)
        _ = self.head.childs  # expand the root right away
        universe.newOpen(self.head)

    def spawnWorker(self):
        self.worker = Worker(self.head.universe)
        self.worker.run()

    def killWorker(self):
        self.worker.kill()

    def performAction(self, action):
        """Move ``head`` to the child reached by ``action``.

        Raises Exception when no child of the current head carries that action.
        """
        for c in self.head.childs:
            if action == c.lastAction:
                self.head.universe.clearPQ()
                self.head.kill()
                self.head = c
                self.head.universe.activateEdge(self.head)
                return
        raise Exception('No such action avaible...')

    def turn(self, bot=None, calcDepth=3, bg=True):
        """Play one turn.

        bot: True lets the engine move, False asks the user for a move,
            None first asks on the console who should move (and also offers
            'undo' and 'qlen').
        calcDepth: search depth handed to forceStrong for engine moves.
        bg: accepted for signature compatibility; not used here.
        """
        print(str(self.head))
        if bot is None:
            c = choose('Select action?', ['human', 'bot', 'undo', 'qlen'])
            if c == 'undo':
                if self.head.parent is None:
                    # At the root there is no earlier position to return to.
                    print('[!] Nothing to undo.')
                    return
                self.head = self.head.parent
                return
            elif c == 'qlen':
                print(self.head.universe.pq.qsize())
                return
            bot = c == 'bot'
        if bot:
            self.head.forceStrong(calcDepth)
            opts = []
            for c in self.head.childs:
                opts.append((c, c.getStrongFor(self.head.curPlayer)))
            opts.sort(key=lambda x: x[1])  # lowest (best) value first
            print('[i] Evaluated Options:')
            for o in opts:
                print('[ ]' + str(o[0].lastAction) + " (Score: "+str(o[1])+")")
            print('[#] I choose to play: ' + str(opts[0][0].lastAction))
            self.performAction(opts[0][0].lastAction)
        else:
            action = self.head.askUserForAction()
            self.performAction(action)

    def game(self, bots=None, calcDepth=7, bg=True):
        """Play turns until the game has ended and announce the result.

        bots: per-player list of True / False / None as understood by turn();
            defaults to None for every player.
        bg: when true, a background Worker runs for the duration of the game.
        """
        if bg:
            self.spawnWorker()
        try:
            if bots is None:
                bots = [None]*self.head.playersNum
            while self.head.getWinner() is None:
                self.turn(bots[self.head.curPlayer], calcDepth, bg=bg)
            # Index -1 (draw) selects the last entry, 'No one'.
            print(['O','X','No one'][self.head.getWinner()] + ' won!')
        finally:
            # Stop the worker even when a turn raised, so its thread does
            # not keep running after the game loop is gone.
            if bg:
                self.killWorker()
2022-04-14 21:05:45 +02:00
2022-04-15 00:32:48 +02:00
class NeuralRuntime(Runtime):
    """Runtime whose nodes are scored by a stored neural model."""

    def __init__(self, initState):
        super().__init__(initState)
        # Build the model the state class provides and fill in stored weights.
        brain = self.head.state.getModel()
        weights = torch.load('brains/uttt.pth')
        brain.load_state_dict(weights)
        brain.eval()
        # From now on the universe scores nodes through the model.
        universe = self.head.universe
        universe.model = brain
        universe.scoreProvider = 'neural'
2022-04-14 21:05:45 +02:00
class Trainer(Runtime):
    """Self-play trainer: plays out games from the root and fits a model to
    the backed-up node values along the played line."""

    def __init__(self, initState):
        super().__init__(initState)
        self.universe = self.head.universe
        self.rootNode = self.head
        self.terminal = None

    def buildDatasetFromModel(self, model, depth=4, refining=True, fanOut=(5,5,5,5,4,4,4,4), uncertainSec=15, exacity=5):
        """Play one game and optionally deepen the search around its line.

        fanOut: search depths applied to the successive ancestors of the
            final node, nearest ancestor first.
        uncertainSec: seconds the background worker is given afterwards.

        Returns the final node of the played game.
        """
        print('[*] Building Timeline')
        term = self.linearPlay(model, calcDepth=depth, exacity=exacity)
        if refining:
            print('[*] Refining Timeline (exploring alternative endings)')
            cur = term
            for d in fanOut:
                if cur.parent is None:
                    # Game was shorter than fanOut; nothing above the root.
                    break
                cur = cur.parent
                cur.forceStrong(d)
                print('.', end='', flush=True)
            print('')
            print('[*] Refining Timeline (exploring uncertain regions)')
            self.timelineExpandUncertain(term, uncertainSec)
        return term

    def linearPlay(self, model, calcDepth=7, exacity=5, verbose=False, firstNRandom=2):
        """Play one game from the root node and return its final node.

        The first ``firstNRandom`` moves are picked uniformly at random.
        Afterwards the options are sorted best-first and the index is drawn
        as random()**exacity * (n-1), so a larger ``exacity`` favours the
        best move; from exacity 10 on the best move is always taken.
        """
        head = self.rootNode
        self.universe.model = model
        self.spawnWorker()
        try:
            while head.getWinner() is None:
                if verbose:
                    print(head)
                else:
                    print('.', end='', flush=True)
                head.forceStrong(calcDepth)
                opts = []
                if len(head.childs) == 0:
                    break
                for c in head.childs:
                    opts.append((c, c.getStrongFor(head.curPlayer)))
                if firstNRandom:
                    firstNRandom -= 1
                    ind = int(random.random()*len(opts))
                else:
                    opts.sort(key=lambda x: x[1])
                    if exacity >= 10:
                        ind = 0
                    else:
                        ind = int(pow(random.random(), exacity)*(len(opts)-1))
                head = opts[ind][0]
        finally:
            # Always stop the worker thread, even if the play-out failed.
            self.killWorker()
        if verbose:
            print(head)
        print(' => '+['O','X','No one'][head.getWinner()] + ' won!')
        return head

    def timelineIter(self, term):
        """Yield the nodes from ``term`` up to the root, each followed by one
        randomly chosen child of it (when it has any)."""
        head = term
        while True:
            yield head
            if len(head.childs):
                yield random.choice(head.childs)
            if head.parent is None:
                return
            head = head.parent

    def timelineExpandUncertain(self, term, secs):
        """Let the background worker process the re-activated tree for ``secs`` seconds."""
        self.rootNode.universe.clearPQ()
        self.rootNode.universe.activateEdge(self.rootNode)
        self.spawnWorker()
        try:
            for s in range(secs):
                time.sleep(1)
                print('.', end='', flush=True)
            self.rootNode.universe.clearPQ()
        finally:
            self.killWorker()
        print('')

    def trainModel(self, model, lr=0.00005, cut=0.01, calcDepth=4, exacity=5, term=None):
        """Fit ``model`` to the backed-up values along a timeline.

        Runs up to 64 passes over timelineIter(term); a timeline is built
        with buildDatasetFromModel when ``term`` is not given.  After pass 16
        training stops early once loss_sum divided by the last node index
        drops below ``cut`` or the summed loss equals that of the previous
        pass.

        Returns the summed loss of the last pass.
        """
        loss_func = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr)
        if term is None:
            term = self.buildDatasetFromModel(model, depth=calcDepth, exacity=exacity)
        print('[*] Conditioning Brain')
        # Summed loss of the previous pass. Kept outside the loop so the
        # stagnation check below compares against the last pass, not 0.
        lLoss = 0
        for r in range(64):
            loss_sum = 0
            zeroLen = 0
            for i, node in enumerate(self.timelineIter(term)):
                for p in range(self.rootNode.playersNum):
                    inp = node.state.getTensor(player=p)
                    gol = torch.tensor([node.getStrongFor(p)], dtype=torch.float)
                    out = model(inp)
                    loss = loss_func(out, gol)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    loss_sum += loss.item()
                    if loss.item() == 0.0:
                        zeroLen += 1
                    if zeroLen == 5:
                        break
            # ``i`` is 0 for a single-node timeline; avoid dividing by zero.
            if r > 16 and (loss_sum/max(i, 1) < cut or lLoss == loss_sum):
                return loss_sum
            lLoss = loss_sum
        return loss_sum

    def main(self, model=None, gens=1024, startGen=0):
        """Train for ``gens`` generations, saving the model after each one.

        Without a model a new one is created and the first generation is
        scored with the naive provider; afterwards the neural provider is used.
        """
        newModel = False
        if model is None:
            print('[!] No brain found. Creating new one...')
            newModel = True
            model = self.rootNode.state.getModel()
        self.universe.scoreProvider = 'naive' if newModel else 'neural'
        model.train()
        for gen in range(startGen, startGen+gens):
            print('[#####] Gen '+str(gen)+' training:')
            loss = self.trainModel(model, calcDepth=min(4,3+int(gen/16)), exacity=int(gen/3+1))
            print('[L] '+str(loss))
            self.universe.scoreProvider = 'neural'
            self.saveModel(model, gen)

    def saveModel(self, model, gen):
        """Pickle ``(gen, state_dict)`` to the model file."""
        dat = model.state_dict()
        path = self.getModelFileName()
        directory = os.path.dirname(path)
        if directory:
            # Create the target directory on first save.
            os.makedirs(directory, exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump((gen, dat), f)

    def loadModelState(self, model):
        """Load the stored state dict into ``model`` and return the stored generation."""
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # model files from a trusted source.
        with open(self.getModelFileName(), 'rb') as f:
            gen, dat = pickle.load(f)
        model.load_state_dict(dat)
        model.eval()
        return gen

    def loadModel(self):
        model = self.rootNode.state.getModel()
        gen = self.loadModelState(model)
        return model, gen

    def train(self):
        """Resume training from the stored model if there is one, else start fresh."""
        if os.path.exists(self.getModelFileName()):
            model, gen = self.loadModel()
            self.main(model, startGen=gen+1)
        else:
            self.main()

    def getModelFileName(self):
        return 'brains/utt.vac'

    def trainFromTerm(self, term):
        """Train the stored model on an existing timeline and save it again.

        Uses loadModel/saveModel so the file name and the (gen, state_dict)
        format match what the rest of this class reads and writes; the stored
        generation number is kept.
        """
        model, gen = self.loadModel()
        self.universe.scoreProvider = 'neural'
        self.trainModel(model, calcDepth=4, exacity=10, term=term)
        self.saveModel(model, gen)

    def saveToMemoryBank(self, term):
        # Currently disabled: the early return skips the dump below.
        return
        with open('memoryBank/uttt/'+datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')+'_'+str(int(random.random()*99999))+'.vdm', 'wb') as f:
            pickle.dump(term, f)
2022-04-15 14:14:41 +02:00