diff --git a/slate/slate.py b/slate/slate.py index f798e8e..5eed72b 100644 --- a/slate/slate.py +++ b/slate/slate.py @@ -2,6 +2,7 @@ import wandb import yaml import os import math +import time import random import copy import re @@ -307,16 +308,28 @@ class Slate(): wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE')) if 'job_type' in wandbC and len(wandbC['job_type']) > 62: wandbC['job_type'] = "..."+wandbC['job_type'][-50:] - with wandb.init( - project=project, - config=copy.deepcopy(runnerConf), - reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT), - settings=wandb.Settings(**self.consume(wandbC, 'settings', {})), - **wandbC - ) as run: - runner = Runner(self, runnerConf) - runner.setup() - runner.run(run) + + retry = 5 + while retry: + try: + with wandb.init( + project=project, + config=copy.deepcopy(runnerConf), + reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT), + settings=wandb.Settings(**self.consume(wandbC, 'settings', {})), + **wandbC + ) as run: + runner = Runner(self, runnerConf) + runner.setup() + runner.run(run) + except wandb.errors.CommError as e: + retry -= 1 + if retry: + print('Catched CommErr; retrying...') + time.sleep(int(60*random.random())) + else: + print('Catched CommErr; not retrying') + raise e if runnerConf != {}: msg = ('Config was not completely consumed: ', runnerConf)