bug fixes
This commit is contained in:
parent b4f6e87395
commit 2ce2e8c384

main.py: 21 changed lines
@@ -13,6 +13,7 @@ from pycallgraph2.output import GraphvizOutput
 from slate import Slate, Slate_Runner
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+#device = 'cpu'
 
 class SpikeRunner(Slate_Runner):
     def setup(self, name):
@@ -101,6 +102,8 @@ class SpikeRunner(Slate_Runner):
 
         for epoch in range(self.epochs):
             total_loss = 0
+            errs = []
+            rels = []
             for batch_num in range(self.num_batches):
 
                 # Create indices for training data and shuffle them
@@ -115,7 +118,7 @@ class SpikeRunner(Slate_Runner):
                 lead_data = self.train_data[idx][:min_length]
 
                 # Slide a window over the data with overlap
-                stride = max(1, self.input_size // 8)  # Ensuring stride is at least 1
+                stride = max(1, self.input_size // 3)  # Ensuring stride is at least 1
                 for i in range(0, len(lead_data) - self.input_size-1, stride):
                     lead_segment = lead_data[i:i + self.input_size]
                     inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
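
The stride divisor change (8 to 3) is the substantive edit here: windows now overlap less, so each epoch visits fewer, more distinct segments per lead. A minimal sketch of the effect, using a hypothetical lead length and input_size (neither value is fixed by this diff):

# Illustrative only: how the divisor changes the window count per lead.
input_size = 64   # hypothetical
n_samples = 1000  # hypothetical lead length
for divisor in (8, 3):  # old vs. new
    stride = max(1, input_size // divisor)
    n_windows = len(range(0, n_samples - input_size - 1, stride))
    print(f'divisor {divisor}: stride {stride}, {n_windows} windows')
# divisor 8: stride 8, 117 windows
# divisor 3: stride 21, 45 windows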
@@ -123,9 +126,9 @@ class SpikeRunner(Slate_Runner):
                     # Collect the segments for the current lead and its peers
                     peer_segments = []
                     for peer_idx in self.sorted_peer_indices[idx]:
-                        peer_segment = self.train_data[peer_idx][i:i + self.input_size][:min_length]
+                        peer_segment = self.train_data[peer_idx][i:i + self.input_size]
                         peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
-                    peer_correlation = torch.tensor([self.correlation_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device)  # Shape: (num_peers)
+                    peer_correlation = torch.tensor([self.correlation_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device)
                     peer_correlations.append(peer_correlation)
 
                     # Stack the segments to form the batch
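
The fix in this hunk drops a stray [:min_length] on peer_segment so peer windows match the lead window length exactly. For reference, peer_correlation just gathers one value per peer from a precomputed matrix; a standalone sketch with dummy data (the 4-lead matrix and index list are placeholders, not from the repo):

import numpy as np
import torch

correlation_matrix = np.corrcoef(np.random.randn(4, 256))  # hypothetical 4 leads
sorted_peer_indices = {0: [2, 1, 3]}  # peers of lead 0, e.g. sorted by correlation
idx = 0
peer_correlation = torch.tensor(
    [correlation_matrix[idx, p] for p in sorted_peer_indices[idx]],
    dtype=torch.float32)  # shape: (num_peers,)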
@@ -145,13 +148,20 @@ class SpikeRunner(Slate_Runner):
                     prediction = self.predictor(new_latent)
 
                     # Calculate loss and backpropagate
-                    loss = self.criterion(prediction, torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device))
+                    tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
+                    loss = self.criterion(prediction, tar)
+                    err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
+                    rel = err / np.sum(tar.cpu().detach().numpy())
                     total_loss += loss.item()
+                    errs.append(err.item())
+                    rels.append(rel.item())
                     self.optimizer.zero_grad()
                     loss.backward()
                     self.optimizer.step()
 
-            wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
+            tot_err = sum(errs)/len(errs)
+            tot_rel = sum(rels)/len(rels)
+            wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "rel": tot_rel}, step=epoch)
             print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
 
             if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
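
Besides splitting the target tensor out into tar, this hunk adds two diagnostics: err, the total absolute error of the batch, and rel, that error relative to the summed targets; both are averaged per epoch and logged to wandb. A standalone sketch on dummy tensors (shapes are illustrative, and rel implicitly assumes the targets sum to a nonzero value):

import numpy as np
import torch

prediction = torch.randn(8, 1)
tar = torch.rand(8, 1) + 0.5  # kept positive so the ratio is well defined
err = np.sum(np.abs(prediction.numpy() - tar.numpy()))  # total absolute error
rel = err / np.sum(tar.numpy())                         # error relative to target mass
errs, rels = [err.item()], [rel.item()]
tot_err = sum(errs) / len(errs)  # per-epoch means, as logged above
tot_rel = sum(rels) / len(rels)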
@@ -281,7 +291,6 @@ class SpikeRunner(Slate_Runner):
             print('Evaluation done for this epoch.')
         return avg_loss
 
-
     def save_models(self, epoch):
         return
         print('Saving models...')
@@ -37,8 +37,9 @@ class LatentRNNProjector(nn.Module):
         self.latent_size = latent_size
 
     def forward(self, x):
-        out, _ = self.rnn(x)
-        latent = self.fc(out)
+        batch_1, batch_2, timesteps = x.size()
+        out, _ = self.rnn(x.view(batch_1 * batch_2, timesteps))
+        latent = self.fc(out).view(batch_1, batch_2, self.latent_size)
         return latent
 
 class MiddleOut(nn.Module):
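
The forward fix makes the projector accept a (batch_1, batch_2, timesteps) input by flattening the two leading batch dimensions before the RNN and restoring them after the linear projection. A minimal sketch of that pattern, assuming scalar timesteps (feature size 1), batch_first=True, a hypothetical latent_size of 16, and taking the last RNN step as the segment summary; the repo's actual layer shapes may differ:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=1, hidden_size=32, batch_first=True)
fc = nn.Linear(32, 16)  # hypothetical latent_size = 16

x = torch.randn(4, 3, 50)            # (batch_1, batch_2, timesteps)
b1, b2, t = x.size()
out, _ = rnn(x.view(b1 * b2, t, 1))  # merge the two batch dims
latent = fc(out[:, -1]).view(b1, b2, 16)  # restore them on the way out
print(latent.shape)                  # torch.Size([4, 3, 16])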
@@ -57,7 +58,7 @@ class MiddleOut(nn.Module):
 
         new_latents = torch.stack(new_latents)
         averaged_latent = torch.mean(new_latents, dim=0)
-        return my_latent - averaged_latent
+        return averaged_latent
 
 class Predictor(nn.Module):
     def __init__(self, output_size, layer_shapes, activations):
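
MiddleOut previously returned the residual my_latent - averaged_latent; it now returns the peer average itself. Dummy shapes to show what changed (the 16-dim latent and 3 peers are placeholders):

import torch

my_latent = torch.randn(16)                                     # current lead's latent
new_latents = torch.stack([torch.randn(16) for _ in range(3)])  # per-peer latents
averaged_latent = torch.mean(new_latents, dim=0)
# before this commit: return my_latent - averaged_latent
# after this commit:  return averaged_latent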