more wandb logging

parent 16ba578737
commit 5eab625cae

main.py (13 changed lines)
@@ -47,7 +47,7 @@ class SpikeRunner(Slate_Runner):
         latent_size = slate.consume(config, 'latent_projector.latent_size')
         input_size = slate.consume(config, 'latent_projector.input_size')
         region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
-        device = slate.consume(training_config, 'device')
+        device = slate.consume(training_config, 'device', 'auto')
         if device == 'auto':
             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.device = device
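As a reading aid: the hunk above makes 'auto' the default and resolves it to CUDA when available. A minimal standalone sketch of the same fallback pattern, with a plain dict.get standing in for slate.consume (whose default-argument behavior is assumed here):

import torch

# 'auto' device fallback, as in the hunk above; dict.get mimics the
# assumed slate.consume(config, key, default) signature.
config = {'device': 'auto'}
device = config.get('device', 'auto')
if device == 'auto':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)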
@@ -110,6 +110,7 @@ class SpikeRunner(Slate_Runner):
         total_loss = 0
         errs = []
         rels = []
+        derrs = []
         for batch_num in range(self.num_batches):

             # Create indices for training data and shuffle them
@@ -119,6 +120,7 @@ class SpikeRunner(Slate_Runner):
             stacked_segments = []
             peer_metrics = []
             targets = []
+            lasts = []

             for idx in indices[:self.batch_size]:
                 lead_data = self.train_data[idx][:min_length]
@@ -143,6 +145,8 @@ class SpikeRunner(Slate_Runner):
                 stacked_segments.append(stacked_segment)
                 target = lead_data[i + self.input_size + 1]
                 targets.append(target)
+                last = lead_data[i + self.input_size]
+                lasts.append(last)

             # Pass the batch through the projector
             latents = self.projector(torch.stack(stacked_segments)/self.value_scale)
@@ -164,10 +168,13 @@ class SpikeRunner(Slate_Runner):

             # Calculate loss and backpropagate
             tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
+            las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
             loss = self.criterion(prediction, tar)
             err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
+            derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
             rel = err / np.sum(tar.cpu().detach().numpy())
             total_loss += loss.item()
+            derrs.append(derr/np.prod(tar.size()).item())
             errs.append(err/np.prod(tar.size()).item())
             rels.append(rel.item())
             self.optimizer.zero_grad()
@@ -175,8 +182,10 @@ class SpikeRunner(Slate_Runner):
             self.optimizer.step()

         tot_err = sum(errs)/len(errs)
+        tot_derr = sum(derrs)/len(derrs)
+        adv_delta = tot_derr / tot_err
         approx_ratio = 1/(sum(rels)/len(rels))
-        wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio}, step=epoch)
+        wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch)
         print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')

         if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
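The new derr/adv_delta numbers compare the model against a last-value ("persistence") baseline: derr is the L1 error you would get by predicting that the next sample equals the most recent one, so adv_delta = tot_derr / tot_err reads as how many times better the model is than simply repeating the last sample. A minimal self-contained sketch of that computation, with hypothetical array names (only tar/las/err/derr/adv_delta come from the diff):

import numpy as np

# prediction: model output, target: next sample, last: most recent observed
# sample (the persistence baseline); all shaped (batch, 1) as in the diff.
def batch_metrics(prediction, target, last):
    n = target.size
    err = np.sum(np.abs(prediction - target)) / n   # model MAE, as in errs.append(...)
    derr = np.sum(np.abs(last - target)) / n        # baseline MAE, as in derrs.append(...)
    return err, derr

pred = np.array([[0.9], [2.1], [2.9]])
tar = np.array([[1.0], [2.0], [3.0]])
las = np.array([[0.5], [1.5], [2.5]])
err, derr = batch_metrics(pred, tar, las)
adv_delta = derr / err  # > 1 means the model beats repeating the last sample
print(f"err={err:.3f} derr={derr:.3f} adv_delta={adv_delta:.2f}")

Logging adv_delta alongside the raw loss gives a scale-free sanity check per epoch: a value hovering near 1 would mean the predictor has learned little beyond copying its input.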