diff --git a/main.py b/main.py index 6b6e274..00b425a 100644 --- a/main.py +++ b/main.py @@ -47,7 +47,7 @@ class SpikeRunner(Slate_Runner): latent_size = slate.consume(config, 'latent_projector.latent_size') input_size = slate.consume(config, 'latent_projector.input_size') region_latent_size = slate.consume(config, 'middle_out.region_latent_size') - device = slate.consume(training_config, 'device') + device = slate.consume(training_config, 'device', 'auto') if device == 'auto': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = device @@ -110,6 +110,7 @@ class SpikeRunner(Slate_Runner): total_loss = 0 errs = [] rels = [] + derrs = [] for batch_num in range(self.num_batches): # Create indices for training data and shuffle them @@ -119,6 +120,7 @@ class SpikeRunner(Slate_Runner): stacked_segments = [] peer_metrics = [] targets = [] + lasts = [] for idx in indices[:self.batch_size]: lead_data = self.train_data[idx][:min_length] @@ -143,6 +145,8 @@ class SpikeRunner(Slate_Runner): stacked_segments.append(stacked_segment) target = lead_data[i + self.input_size + 1] targets.append(target) + last = lead_data[i + self.input_size] + lasts.append(last) # Pass the batch through the projector latents = self.projector(torch.stack(stacked_segments)/self.value_scale) @@ -164,10 +168,13 @@ class SpikeRunner(Slate_Runner): # Calculate loss and backpropagate tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device) + las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy() loss = self.criterion(prediction, tar) err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy())) + derr = np.sum(np.abs(las - tar.cpu().detach().numpy())) rel = err / np.sum(tar.cpu().detach().numpy()) total_loss += loss.item() + derrs.append(derr/np.prod(tar.size()).item()) errs.append(err/np.prod(tar.size()).item()) rels.append(rel.item()) self.optimizer.zero_grad() @@ -175,8 +182,10 @@ class SpikeRunner(Slate_Runner): self.optimizer.step() tot_err = sum(errs)/len(errs) + tot_derr = sum(derrs)/len(derrs) + adv_delta = tot_derr / tot_err approx_ratio = 1/(sum(rels)/len(rels)) - wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio}, step=epoch) + wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch) print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}') if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0: