BrokeBrokenn
This commit is contained in:
parent
de5137ecd3
commit
6cc2d84519
BIN
brains/uttt.pth
BIN
brains/uttt.pth
Binary file not shown.
@ -127,7 +127,11 @@ class TTTState(State):
|
|||||||
def getTensor(self, player=None, phase='default'):
|
def getTensor(self, player=None, phase='default'):
|
||||||
if player==None:
|
if player==None:
|
||||||
player = self.curPlayer
|
player = self.curPlayer
|
||||||
return torch.tensor([self.symbToNum(b) for b in self.board])
|
s = ''
|
||||||
|
for row in range(1, 10):
|
||||||
|
for col in range(1, 10):
|
||||||
|
s += self.board[self.index(row, col)]
|
||||||
|
return torch.tensor([self.symbToNum(b) for b in s])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def getModel(cls, phase='default'):
|
def getModel(cls, phase='default'):
|
||||||
@ -138,8 +142,7 @@ class Model(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.smolChan = 12
|
self.smolChan = 12
|
||||||
self.bigChan = 5
|
self.compChan = 7
|
||||||
self.compChan = 3
|
|
||||||
|
|
||||||
self.smol = nn.Sequential(
|
self.smol = nn.Sequential(
|
||||||
nn.Conv2d(
|
nn.Conv2d(
|
||||||
@ -152,35 +155,26 @@ class Model(nn.Module):
|
|||||||
nn.ReLU()
|
nn.ReLU()
|
||||||
)
|
)
|
||||||
self.big = nn.Sequential(
|
self.big = nn.Sequential(
|
||||||
nn.Conv2d(
|
nn.Linear(self.smolChan*9, self.compChan),
|
||||||
in_channels=self.smolChan,
|
#nn.ReLU(),
|
||||||
out_channels=self.bigChan,
|
#nn.Linear(self.compChan, 1),
|
||||||
kernel_size=(3,3),
|
|
||||||
stride=3,
|
|
||||||
padding=0,
|
|
||||||
),
|
|
||||||
nn.ReLU()
|
|
||||||
)
|
|
||||||
self.out = nn.Sequential(
|
|
||||||
#nn.Linear(bigChan, 1),
|
|
||||||
nn.Linear(self.bigChan, self.compChan),
|
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(self.compChan, 1),
|
nn.Linear(self.compChan, 3),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(3, 1),
|
||||||
nn.Sigmoid()
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = torch.reshape(x, (1,9,9))
|
x = torch.reshape(x, (1,9,9))
|
||||||
x = self.smol(x)
|
x = self.smol(x)
|
||||||
x = self.big(x)
|
x = torch.reshape(x, (self.smolChan*9,))
|
||||||
x = torch.reshape(x, (self.bigChan,))
|
y = self.big(x)
|
||||||
#x = x.view(x.size(0), -1)
|
|
||||||
y = self.out(x)
|
|
||||||
return y
|
return y
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__=="__main__":
|
||||||
run = NeuralRuntime(TTTState())
|
run = NeuralRuntime(TTTState())
|
||||||
run.game(None, 4)
|
run.game(None, 4)
|
||||||
|
|
||||||
#trainer = Trainer(TTTState())
|
trainer = Trainer(TTTState())
|
||||||
#trainer.train()
|
trainer.train()
|
||||||
|
@ -186,6 +186,8 @@ class Node():
|
|||||||
|
|
||||||
def _expand(self):
|
def _expand(self):
|
||||||
self._childs = []
|
self._childs = []
|
||||||
|
if self.getWinner()!=None:
|
||||||
|
return
|
||||||
actions = self.state.getAvaibleActions()
|
actions = self.state.getAvaibleActions()
|
||||||
for action in actions:
|
for action in actions:
|
||||||
newNode = Node(self.state.mutate(action), self.universe, self, action)
|
newNode = Node(self.state.mutate(action), self.universe, self, action)
|
||||||
@ -284,11 +286,17 @@ class Node():
|
|||||||
self._calcScore(p)
|
self._calcScore(p)
|
||||||
|
|
||||||
def _calcScore(self, player):
|
def _calcScore(self, player):
|
||||||
|
winner = self.getWinner()
|
||||||
|
if winner!=None:
|
||||||
|
if winner==player:
|
||||||
|
self._scores[player] = 0.0
|
||||||
|
else:
|
||||||
|
self._scores[player] = 1.0
|
||||||
|
return
|
||||||
if self.universe.scoreProvider == 'naive':
|
if self.universe.scoreProvider == 'naive':
|
||||||
self._scores[player] = self.state.getScoreFor(player)
|
self._scores[player] = self.state.getScoreFor(player)
|
||||||
elif self.universe.scoreProvider == 'neural':
|
elif self.universe.scoreProvider == 'neural':
|
||||||
self._scores[player] = self.state.getScoreNeural(self.universe.model, player)
|
self._scores[player] = self.state.getScoreNeural(self.universe.model, player)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception('Uknown Score-Provider')
|
raise Exception('Uknown Score-Provider')
|
||||||
|
|
||||||
@ -329,7 +337,7 @@ class Node():
|
|||||||
s.append("[ -> "+str(self.lastAction)+" ]")
|
s.append("[ -> "+str(self.lastAction)+" ]")
|
||||||
s.append("[ turn: "+str(self.state.curPlayer)+" ]")
|
s.append("[ turn: "+str(self.state.curPlayer)+" ]")
|
||||||
s.append(str(self.state))
|
s.append(str(self.state))
|
||||||
s.append("[ score: "+str(self.getStrongFor(self.state.curPlayer))+" ]")
|
s.append("[ score: "+str(self.getScoreFor(0))+" ]")
|
||||||
return '\n'.join(s)
|
return '\n'.join(s)
|
||||||
|
|
||||||
def choose(txt, options):
|
def choose(txt, options):
|
||||||
@ -452,7 +460,7 @@ class Trainer(Runtime):
|
|||||||
self.rootNode = Node(initState, universe = self.universe)
|
self.rootNode = Node(initState, universe = self.universe)
|
||||||
self.terminal = None
|
self.terminal = None
|
||||||
|
|
||||||
def buildDatasetFromModel(self, model, depth=4, refining=False):
|
def buildDatasetFromModel(self, model, depth=4, refining=True):
|
||||||
print('[*] Building Timeline')
|
print('[*] Building Timeline')
|
||||||
term = self.linearPlay(model, calcDepth=depth)
|
term = self.linearPlay(model, calcDepth=depth)
|
||||||
if refining:
|
if refining:
|
||||||
@ -462,8 +470,8 @@ class Trainer(Runtime):
|
|||||||
self.fanOut(term.parent.parent, depth=depth+1)
|
self.fanOut(term.parent.parent, depth=depth+1)
|
||||||
return term
|
return term
|
||||||
|
|
||||||
def fanOut(self, head, depth=10):
|
def fanOut(self, head, depth=4):
|
||||||
for d in range(max(3, depth-3)):
|
for d in range(max(1, depth-2)):
|
||||||
head = head.parent
|
head = head.parent
|
||||||
head.forceStrong(depth)
|
head.forceStrong(depth)
|
||||||
|
|
||||||
@ -499,7 +507,7 @@ class Trainer(Runtime):
|
|||||||
loss_func = nn.MSELoss()
|
loss_func = nn.MSELoss()
|
||||||
optimizer = optim.Adam(model.parameters(), lr)
|
optimizer = optim.Adam(model.parameters(), lr)
|
||||||
term = self.buildDatasetFromModel(model, depth=calcDepth)
|
term = self.buildDatasetFromModel(model, depth=calcDepth)
|
||||||
for r in range(16):
|
for r in range(64):
|
||||||
loss_sum = 0
|
loss_sum = 0
|
||||||
zeroLen = 0
|
zeroLen = 0
|
||||||
for i, node in enumerate(self.timelineIter(term)):
|
for i, node in enumerate(self.timelineIter(term)):
|
||||||
|
Loading…
Reference in New Issue
Block a user