More config values
parent cfa2b48207
commit cd62505ef1

main.py (37 changed lines)
@@ -12,8 +12,6 @@ from pycallgraph2 import PyCallGraph
 from pycallgraph2.output import GraphvizOutput
 from slate import Slate, Slate_Runner
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-#device = 'cpu'
 
 class SpikeRunner(Slate_Runner):
     def setup(self, name):
@@ -49,6 +47,10 @@ class SpikeRunner(Slate_Runner):
         latent_size = slate.consume(config, 'latent_projector.latent_size')
         input_size = slate.consume(config, 'latent_projector.input_size')
         region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
+        device = slate.consume(training_config, 'device')
+        if device == 'auto':
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = device
 
         if latent_projector_type == 'fc':
             self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
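Note: the new device handling only autodetects when the config value is the string 'auto'; any concrete value is used as given. A minimal standalone sketch of the pattern (resolve_device is a hypothetical helper, not part of this repo):

import torch

def resolve_device(name: str) -> torch.device:
    # 'auto' picks CUDA when available, otherwise CPU;
    # concrete names ('cpu', 'cuda:1', ...) pass straight through.
    if name == 'auto':
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.device(name)

print(resolve_device('auto'))  # device(type='cuda') on a GPU machine, else cpu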
@@ -68,7 +70,8 @@ class SpikeRunner(Slate_Runner):
         self.learning_rate = slate.consume(training_config, 'learning_rate')
         self.eval_freq = slate.consume(training_config, 'eval_freq')
         self.save_path = slate.consume(training_config, 'save_path')
-        self.peer_gradients = slate.consume(training_config, 'peer_gradients')
+        self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0))
+        self.value_scale = slate.consume(training_config, 'value_scale')
 
         # Evaluation parameter
         self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
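Note: value_scale is consumed here and applied in train_model below; inputs are divided by it before the projector and predictions are multiplied by it afterwards, so the networks operate on normalized values. A minimal sketch of that round trip (the scale value and toy model are illustrative assumptions):

import torch

value_scale = 1000.0               # illustrative; the real value comes from the config
toy_model = torch.nn.Linear(8, 1)  # stand-in for projector -> middle_out -> predictor

raw = torch.randn(4, 8) * value_scale              # raw signal with a large dynamic range
pred = toy_model(raw / value_scale) * value_scale  # normalize in, de-normalize out
print(pred.shape)                                  # torch.Size([4, 1]), back on the raw scale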
@@ -98,8 +101,9 @@ class SpikeRunner(Slate_Runner):
         self.train_model()
 
     def train_model(self):
+        device = self.device
         min_length = min([len(seq) for seq in self.train_data])
 
         best_test_score = float('inf')
 
         for epoch in range(self.epochs):
@@ -121,7 +125,8 @@ class SpikeRunner(Slate_Runner):
 
                 # Slide a window over the data with overlap
                 stride = max(1, self.input_size // 3) # Ensuring stride is at least 1
-                for i in range(0, len(lead_data) - self.input_size-1, stride):
+                offset = np.random.randint(0, stride)
+                for i in range(offset, len(lead_data) - self.input_size-1-offset, stride):
                     lead_segment = lead_data[i:i + self.input_size]
                     inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
 
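Note: the windowing change keeps the same stride-based overlap but shifts the whole window grid by a fresh random offset on every pass, so training does not always revisit the identical segments. A small sketch of the sampling (sequence and window sizes are made-up numbers):

import numpy as np

seq = np.arange(100)          # stand-in for one lead's samples
input_size = 12
stride = max(1, input_size // 3)

offset = np.random.randint(0, stride)
starts = list(range(offset, len(seq) - input_size - 1 - offset, stride))
windows = [seq[i:i + input_size] for i in starts]
print(offset, len(windows))   # the grid shifts per call; the window count stays similar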
@@ -140,16 +145,22 @@ class SpikeRunner(Slate_Runner):
                     targets.append(target)
 
                 # Pass the batch through the projector
-                latents = self.projector(torch.stack(stacked_segments))
+                latents = self.projector(torch.stack(stacked_segments)/self.value_scale)
 
                 my_latent = latents[:, 0, :]
                 peer_latents = latents[:, 1:, :]
-                if not self.peer_gradients:
-                    peer_latents = peer_latents.detach()
+
+                # Scale gradients during backwards pass as configured
+                if self.peer_gradients_factor == 1.0:
+                    pass
+                elif self.peer_gradients_factor == 0.0:
+                    peer_latents = peer_latents.detach()
+                else:
+                    peer_latents.register_hook(lambda grad: grad*self.peer_gradients_factor)
 
                 # Pass through MiddleOut
-                new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
-                prediction = self.predictor(new_latent)
+                region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
+                prediction = self.predictor(region_latent)*self.value_scale
 
                 # Calculate loss and backpropagate
                 tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
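Note: the boolean peer_gradients becomes a continuous peer_gradients_factor: 1.0 leaves backpropagation untouched, 0.0 detaches the peer latents entirely, and anything in between scales the gradient flowing back into them via a backward hook. A self-contained sketch of the hook mechanism (shapes and factor are illustrative):

import torch

factor = 0.5                                   # stand-in for peer_gradients_factor
peers = torch.randn(3, 4, requires_grad=True)  # stand-in for peer latents
out = peers * 2.0
out.register_hook(lambda grad: grad * factor)  # scales grads during backward()
out.sum().backward()
print(peers.grad)                              # 2.0 * 0.5 = 1.0 everywhere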
@@ -157,15 +168,15 @@ class SpikeRunner(Slate_Runner):
                 err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
                 rel = err / np.sum(tar.cpu().detach().numpy())
                 total_loss += loss.item()
-                errs.append(err.item())
+                errs.append(err/np.prod(tar.size()).item())
                 rels.append(rel.item())
                 self.optimizer.zero_grad()
                 loss.backward()
                 self.optimizer.step()
 
             tot_err = sum(errs)/len(errs)
-            tot_rel = sum(rels)/len(rels)
-            wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "rel": tot_rel}, step=epoch)
+            approx_ratio = 1/(sum(rels)/len(rels))
+            wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio}, step=epoch)
             print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
 
             if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
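Note on the logging change: rel is the total absolute error divided by the total target magnitude, and approx_ratio is its reciprocal, i.e. "the signal is N times larger than the residual". A tiny worked example:

import numpy as np

pred = np.array([ 9.8, 20.5, 29.6])
tar  = np.array([10.0, 20.0, 30.0])

err = np.sum(np.abs(pred - tar))  # 1.1 total absolute error
rel = err / np.sum(tar)           # 1.1 / 60 ≈ 0.0183
print(1 / rel)                    # ≈ 54.5, the logged approx_ratio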
@@ -191,6 +202,8 @@ class SpikeRunner(Slate_Runner):
 
     def evaluate_model(self, epoch):
+        print('Evaluating model...')
+        device = self.device
 
         self.projector.eval()
         self.middle_out.eval()
         self.predictor.eval()
models.py (17 changed lines)
@@ -53,30 +53,31 @@ class LatentFourierProjector(nn.Module):
         super(LatentFourierProjector, self).__init__()
         self.fourier_transform = FourierTransformLayer()
         layers = []
 
         if pass_raw_len is None:
             pass_raw_len = input_size
         else:
             assert pass_raw_len <= input_size
 
         in_features = pass_raw_len + (input_size // 2 + 1) * 2 # (input_size // 2 + 1) real + imaginary parts
         for i, out_features in enumerate(layer_shapes):
             layers.append(nn.Linear(in_features, out_features))
             if activations[i] != 'None':
                 layers.append(get_activation(activations[i]))
             in_features = out_features
 
         layers.append(nn.Linear(in_features, latent_size))
         self.fc = nn.Sequential(*layers)
         self.latent_size = latent_size
         self.pass_raw_len = pass_raw_len
 
     def forward(self, x):
         # Apply Fourier Transform
-        x_fft = self.fourier_transform(x)
+        batch_1, batch_2, timesteps = x.size()
+        x_fft = self.fourier_transform(x.view(batch_1 * batch_2, timesteps))
         # Separate real and imaginary parts and combine them
         x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1)
         # Combine part of the raw input with Fourier features
-        combined_input = torch.cat([x[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
+        combined_input = torch.cat([x.view(batch_1 * batch_2, timesteps)[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
         # Process through fully connected layers
         latent = self.fc(combined_input)
+        latent = latent.view(batch_1, batch_2, self.latent_size)
         return latent
 
 class MiddleOut(nn.Module):
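Note: the projector forward now takes a 3-D batch (windows x leads x timesteps), flattens the first two dimensions before the Fourier transform, and restores them after the fully connected stack. For a length-N signal, a real FFT yields N // 2 + 1 complex bins, which is exactly why in_features adds (input_size // 2 + 1) * 2 real features. A sketch of the flatten/transform/reshape round trip using torch.fft.rfft (assuming FourierTransformLayer wraps something equivalent):

import torch

batch_1, batch_2, timesteps = 4, 5, 32       # windows x leads x samples (made-up sizes)
x = torch.randn(batch_1, batch_2, timesteps)

flat = x.view(batch_1 * batch_2, timesteps)  # merge windows and leads into one batch dim
x_fft = torch.fft.rfft(flat)                 # (20, 32 // 2 + 1) = (20, 17) complex bins
feats = torch.cat((x_fft.real, x_fft.imag), dim=-1)
print(feats.shape)                           # torch.Size([20, 34]) = (N // 2 + 1) * 2
restored = feats.view(batch_1, batch_2, -1)  # undo the flatten, as forward() does
print(restored.shape)                        # torch.Size([4, 5, 34])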
@@ -84,11 +85,15 @@ class MiddleOut(nn.Module):
         super(MiddleOut, self).__init__()
         if residual:
             assert latent_size == region_latent_size
+        if num_peers == 0:
+            assert latent_size == region_latent_size
         self.num_peers = num_peers
         self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
         self.residual = residual
 
     def forward(self, my_latent, peer_latents, peer_metrics):
+        if self.num_peers == 0:
+            return my_latent
         new_latents = []
         for p in range(peer_latents.shape[-2]):
             peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p]
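Note: with num_peers == 0 MiddleOut degenerates to the identity and returns my_latent unchanged, so the new assert guarantees the bypassed output still has the region_latent_size the predictor expects. A stripped-down sketch of the bypass (toy sizes, peer aggregation omitted):

import torch
import torch.nn as nn

class TinyMiddleOut(nn.Module):
    # Minimal stand-in illustrating the num_peers == 0 identity path.
    def __init__(self, latent_size, region_latent_size, num_peers):
        super().__init__()
        if num_peers == 0:
            assert latent_size == region_latent_size  # bypass output must fit the predictor
        self.num_peers = num_peers
        self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)

    def forward(self, my_latent, peer_latents, peer_metrics):
        if self.num_peers == 0:
            return my_latent  # no peers: pass the latent straight through
        return my_latent      # (real peer aggregation omitted in this sketch)

m = TinyMiddleOut(latent_size=8, region_latent_size=8, num_peers=0)
print(m(torch.randn(2, 8), None, None).shape)  # torch.Size([2, 8])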