Switch to using torchrl loggers
This commit is contained in:
parent
7ea0bdcec6
commit
1d1d9060f9
@ -1,36 +0,0 @@
|
|||||||
class Logger:
|
|
||||||
def __init__(self, push_interval=1):
|
|
||||||
self.data = {}
|
|
||||||
self.push_interval = push_interval
|
|
||||||
|
|
||||||
def log(self, key, value, epoch):
|
|
||||||
if key not in self.data:
|
|
||||||
self.data[key] = []
|
|
||||||
self.data[key].append((epoch, value))
|
|
||||||
|
|
||||||
def end_of_epoch(self, epoch):
|
|
||||||
if epoch % self.push_interval == 0:
|
|
||||||
self.push()
|
|
||||||
|
|
||||||
def push(self):
|
|
||||||
raise NotImplementedError("Push method should be implemented by subclasses")
|
|
||||||
|
|
||||||
class TerminalLogger(Logger):
|
|
||||||
def push(self):
|
|
||||||
for key, values in self.data.items():
|
|
||||||
for epoch, value in values:
|
|
||||||
print(f"Epoch {epoch}: {key} = {value}")
|
|
||||||
self.data = {}
|
|
||||||
|
|
||||||
class WandbLogger(Logger):
|
|
||||||
def __init__(self, project, entity, config, push_interval=1):
|
|
||||||
super().__init__(push_interval)
|
|
||||||
import wandb
|
|
||||||
self.wandb = wandb
|
|
||||||
self.wandb.init(project=project, entity=entity, config=config)
|
|
||||||
|
|
||||||
def push(self):
|
|
||||||
for key, values in self.data.items():
|
|
||||||
for epoch, value in values:
|
|
||||||
self.wandb.log({key: value, 'epoch': epoch})
|
|
||||||
self.data = {}
|
|
Loading…
Reference in New Issue
Block a user