More config values

This commit is contained in:
Dominik Moritz Roth 2024-05-26 17:41:30 +02:00
parent cfa2b48207
commit cd62505ef1
2 changed files with 36 additions and 18 deletions

37
main.py
View File

@ -12,8 +12,6 @@ from pycallgraph2 import PyCallGraph
from pycallgraph2.output import GraphvizOutput
from slate import Slate, Slate_Runner
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'
class SpikeRunner(Slate_Runner):
def setup(self, name):
@ -49,6 +47,10 @@ class SpikeRunner(Slate_Runner):
latent_size = slate.consume(config, 'latent_projector.latent_size')
input_size = slate.consume(config, 'latent_projector.input_size')
region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
device = slate.consume(training_config, 'device')
if device == 'auto':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = device
if latent_projector_type == 'fc':
self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
@ -68,7 +70,8 @@ class SpikeRunner(Slate_Runner):
self.learning_rate = slate.consume(training_config, 'learning_rate')
self.eval_freq = slate.consume(training_config, 'eval_freq')
self.save_path = slate.consume(training_config, 'save_path')
self.peer_gradients = slate.consume(training_config, 'peer_gradients')
self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0))
self.value_scale = slate.consume(training_config, 'value_scale')
# Evaluation parameter
self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
@ -98,8 +101,9 @@ class SpikeRunner(Slate_Runner):
self.train_model()
def train_model(self):
device = self.device
min_length = min([len(seq) for seq in self.train_data])
best_test_score = float('inf')
for epoch in range(self.epochs):
@ -121,7 +125,8 @@ class SpikeRunner(Slate_Runner):
# Slide a window over the data with overlap
stride = max(1, self.input_size // 3) # Ensuring stride is at least 1
for i in range(0, len(lead_data) - self.input_size-1, stride):
offset = np.random.randint(0, stride)
for i in range(offset, len(lead_data) - self.input_size-1-offset, stride):
lead_segment = lead_data[i:i + self.input_size]
inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
@ -140,16 +145,22 @@ class SpikeRunner(Slate_Runner):
targets.append(target)
# Pass the batch through the projector
latents = self.projector(torch.stack(stacked_segments))
latents = self.projector(torch.stack(stacked_segments)/self.value_scale)
my_latent = latents[:, 0, :]
peer_latents = latents[:, 1:, :]
if not self.peer_gradients:
# Scale gradients during backwards pass as configured
if self.peer_gradients_factor == 1.0:
pass
elif self.peer_gradients_factor == 0.0:
peer_latents = peer_latents.detach()
else:
peer_latents.register_hook(lambda grad: grad*self.peer_gradients_factor)
# Pass through MiddleOut
new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
prediction = self.predictor(new_latent)
region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
prediction = self.predictor(region_latent)*self.value_scale
# Calculate loss and backpropagate
tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
@ -157,15 +168,15 @@ class SpikeRunner(Slate_Runner):
err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
rel = err / np.sum(tar.cpu().detach().numpy())
total_loss += loss.item()
errs.append(err.item())
errs.append(err/np.prod(tar.size()).item())
rels.append(rel.item())
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
tot_err = sum(errs)/len(errs)
tot_rel = sum(rels)/len(rels)
wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "rel": tot_rel}, step=epoch)
approx_ratio = 1/(sum(rels)/len(rels))
wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio}, step=epoch)
print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
@ -191,6 +202,8 @@ class SpikeRunner(Slate_Runner):
def evaluate_model(self, epoch):
print('Evaluating model...')
device = self.device
self.projector.eval()
self.middle_out.eval()
self.predictor.eval()

View File

@ -53,30 +53,31 @@ class LatentFourierProjector(nn.Module):
super(LatentFourierProjector, self).__init__()
self.fourier_transform = FourierTransformLayer()
layers = []
if pass_raw_len is None:
pass_raw_len = input_size
else:
assert pass_raw_len <= input_size
in_features = pass_raw_len + (input_size // 2 + 1) * 2 # (input_size // 2 + 1) real + imaginary parts
for i, out_features in enumerate(layer_shapes):
layers.append(nn.Linear(in_features, out_features))
if activations[i] != 'None':
layers.append(get_activation(activations[i]))
in_features = out_features
layers.append(nn.Linear(in_features, latent_size))
self.fc = nn.Sequential(*layers)
self.latent_size = latent_size
self.pass_raw_len = pass_raw_len
def forward(self, x):
# Apply Fourier Transform
x_fft = self.fourier_transform(x)
# Separate real and imaginary parts and combine them
batch_1, batch_2, timesteps = x.size()
x_fft = self.fourier_transform(x.view(batch_1 * batch_2, timesteps))
x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1)
# Combine part of the raw input with Fourier features
combined_input = torch.cat([x[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
# Process through fully connected layers
combined_input = torch.cat([x.view(batch_1 * batch_2, timesteps)[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
latent = self.fc(combined_input)
latent = latent.view(batch_1, batch_2, self.latent_size)
return latent
class MiddleOut(nn.Module):
@ -84,11 +85,15 @@ class MiddleOut(nn.Module):
super(MiddleOut, self).__init__()
if residual:
assert latent_size == region_latent_size
if num_peers == 0:
assert latent_size == region_latent_size
self.num_peers = num_peers
self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
self.residual = residual
def forward(self, my_latent, peer_latents, peer_metrics):
if self.num_peers == 0:
return my_latent
new_latents = []
for p in range(peer_latents.shape[-2]):
peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p]