More config values
This commit is contained in:
parent
cfa2b48207
commit
cd62505ef1
35
main.py
35
main.py
@ -12,8 +12,6 @@ from pycallgraph2 import PyCallGraph
|
|||||||
from pycallgraph2.output import GraphvizOutput
|
from pycallgraph2.output import GraphvizOutput
|
||||||
from slate import Slate, Slate_Runner
|
from slate import Slate, Slate_Runner
|
||||||
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
#device = 'cpu'
|
|
||||||
|
|
||||||
class SpikeRunner(Slate_Runner):
|
class SpikeRunner(Slate_Runner):
|
||||||
def setup(self, name):
|
def setup(self, name):
|
||||||
@ -49,6 +47,10 @@ class SpikeRunner(Slate_Runner):
|
|||||||
latent_size = slate.consume(config, 'latent_projector.latent_size')
|
latent_size = slate.consume(config, 'latent_projector.latent_size')
|
||||||
input_size = slate.consume(config, 'latent_projector.input_size')
|
input_size = slate.consume(config, 'latent_projector.input_size')
|
||||||
region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
|
region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
|
||||||
|
device = slate.consume(training_config, 'device')
|
||||||
|
if device == 'auto':
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self.device = device
|
||||||
|
|
||||||
if latent_projector_type == 'fc':
|
if latent_projector_type == 'fc':
|
||||||
self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
|
self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
|
||||||
@ -68,7 +70,8 @@ class SpikeRunner(Slate_Runner):
|
|||||||
self.learning_rate = slate.consume(training_config, 'learning_rate')
|
self.learning_rate = slate.consume(training_config, 'learning_rate')
|
||||||
self.eval_freq = slate.consume(training_config, 'eval_freq')
|
self.eval_freq = slate.consume(training_config, 'eval_freq')
|
||||||
self.save_path = slate.consume(training_config, 'save_path')
|
self.save_path = slate.consume(training_config, 'save_path')
|
||||||
self.peer_gradients = slate.consume(training_config, 'peer_gradients')
|
self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0))
|
||||||
|
self.value_scale = slate.consume(training_config, 'value_scale')
|
||||||
|
|
||||||
# Evaluation parameter
|
# Evaluation parameter
|
||||||
self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
|
self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
|
||||||
@ -98,6 +101,7 @@ class SpikeRunner(Slate_Runner):
|
|||||||
self.train_model()
|
self.train_model()
|
||||||
|
|
||||||
def train_model(self):
|
def train_model(self):
|
||||||
|
device = self.device
|
||||||
min_length = min([len(seq) for seq in self.train_data])
|
min_length = min([len(seq) for seq in self.train_data])
|
||||||
|
|
||||||
best_test_score = float('inf')
|
best_test_score = float('inf')
|
||||||
@ -121,7 +125,8 @@ class SpikeRunner(Slate_Runner):
|
|||||||
|
|
||||||
# Slide a window over the data with overlap
|
# Slide a window over the data with overlap
|
||||||
stride = max(1, self.input_size // 3) # Ensuring stride is at least 1
|
stride = max(1, self.input_size // 3) # Ensuring stride is at least 1
|
||||||
for i in range(0, len(lead_data) - self.input_size-1, stride):
|
offset = np.random.randint(0, stride)
|
||||||
|
for i in range(offset, len(lead_data) - self.input_size-1-offset, stride):
|
||||||
lead_segment = lead_data[i:i + self.input_size]
|
lead_segment = lead_data[i:i + self.input_size]
|
||||||
inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
|
inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
|
||||||
|
|
||||||
@ -140,16 +145,22 @@ class SpikeRunner(Slate_Runner):
|
|||||||
targets.append(target)
|
targets.append(target)
|
||||||
|
|
||||||
# Pass the batch through the projector
|
# Pass the batch through the projector
|
||||||
latents = self.projector(torch.stack(stacked_segments))
|
latents = self.projector(torch.stack(stacked_segments)/self.value_scale)
|
||||||
|
|
||||||
my_latent = latents[:, 0, :]
|
my_latent = latents[:, 0, :]
|
||||||
peer_latents = latents[:, 1:, :]
|
peer_latents = latents[:, 1:, :]
|
||||||
if not self.peer_gradients:
|
|
||||||
|
# Scale gradients during backwards pass as configured
|
||||||
|
if self.peer_gradients_factor == 1.0:
|
||||||
|
pass
|
||||||
|
elif self.peer_gradients_factor == 0.0:
|
||||||
peer_latents = peer_latents.detach()
|
peer_latents = peer_latents.detach()
|
||||||
|
else:
|
||||||
|
peer_latents.register_hook(lambda grad: grad*self.peer_gradients_factor)
|
||||||
|
|
||||||
# Pass through MiddleOut
|
# Pass through MiddleOut
|
||||||
new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
|
region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
|
||||||
prediction = self.predictor(new_latent)
|
prediction = self.predictor(region_latent)*self.value_scale
|
||||||
|
|
||||||
# Calculate loss and backpropagate
|
# Calculate loss and backpropagate
|
||||||
tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
|
tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
|
||||||
@ -157,15 +168,15 @@ class SpikeRunner(Slate_Runner):
|
|||||||
err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
|
err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
|
||||||
rel = err / np.sum(tar.cpu().detach().numpy())
|
rel = err / np.sum(tar.cpu().detach().numpy())
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
errs.append(err.item())
|
errs.append(err/np.prod(tar.size()).item())
|
||||||
rels.append(rel.item())
|
rels.append(rel.item())
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
tot_err = sum(errs)/len(errs)
|
tot_err = sum(errs)/len(errs)
|
||||||
tot_rel = sum(rels)/len(rels)
|
approx_ratio = 1/(sum(rels)/len(rels))
|
||||||
wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "rel": tot_rel}, step=epoch)
|
wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio}, step=epoch)
|
||||||
print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
|
print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
|
||||||
|
|
||||||
if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
|
if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
|
||||||
@ -191,6 +202,8 @@ class SpikeRunner(Slate_Runner):
|
|||||||
|
|
||||||
def evaluate_model(self, epoch):
|
def evaluate_model(self, epoch):
|
||||||
print('Evaluating model...')
|
print('Evaluating model...')
|
||||||
|
device = self.device
|
||||||
|
|
||||||
self.projector.eval()
|
self.projector.eval()
|
||||||
self.middle_out.eval()
|
self.middle_out.eval()
|
||||||
self.predictor.eval()
|
self.predictor.eval()
|
||||||
|
17
models.py
17
models.py
@ -53,30 +53,31 @@ class LatentFourierProjector(nn.Module):
|
|||||||
super(LatentFourierProjector, self).__init__()
|
super(LatentFourierProjector, self).__init__()
|
||||||
self.fourier_transform = FourierTransformLayer()
|
self.fourier_transform = FourierTransformLayer()
|
||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
if pass_raw_len is None:
|
if pass_raw_len is None:
|
||||||
pass_raw_len = input_size
|
pass_raw_len = input_size
|
||||||
else:
|
else:
|
||||||
assert pass_raw_len <= input_size
|
assert pass_raw_len <= input_size
|
||||||
|
|
||||||
in_features = pass_raw_len + (input_size // 2 + 1) * 2 # (input_size // 2 + 1) real + imaginary parts
|
in_features = pass_raw_len + (input_size // 2 + 1) * 2 # (input_size // 2 + 1) real + imaginary parts
|
||||||
for i, out_features in enumerate(layer_shapes):
|
for i, out_features in enumerate(layer_shapes):
|
||||||
layers.append(nn.Linear(in_features, out_features))
|
layers.append(nn.Linear(in_features, out_features))
|
||||||
if activations[i] != 'None':
|
if activations[i] != 'None':
|
||||||
layers.append(get_activation(activations[i]))
|
layers.append(get_activation(activations[i]))
|
||||||
in_features = out_features
|
in_features = out_features
|
||||||
|
|
||||||
layers.append(nn.Linear(in_features, latent_size))
|
layers.append(nn.Linear(in_features, latent_size))
|
||||||
self.fc = nn.Sequential(*layers)
|
self.fc = nn.Sequential(*layers)
|
||||||
self.latent_size = latent_size
|
self.latent_size = latent_size
|
||||||
self.pass_raw_len = pass_raw_len
|
self.pass_raw_len = pass_raw_len
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Apply Fourier Transform
|
batch_1, batch_2, timesteps = x.size()
|
||||||
x_fft = self.fourier_transform(x)
|
x_fft = self.fourier_transform(x.view(batch_1 * batch_2, timesteps))
|
||||||
# Separate real and imaginary parts and combine them
|
|
||||||
x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1)
|
x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1)
|
||||||
# Combine part of the raw input with Fourier features
|
combined_input = torch.cat([x.view(batch_1 * batch_2, timesteps)[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
|
||||||
combined_input = torch.cat([x[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
|
|
||||||
# Process through fully connected layers
|
|
||||||
latent = self.fc(combined_input)
|
latent = self.fc(combined_input)
|
||||||
|
latent = latent.view(batch_1, batch_2, self.latent_size)
|
||||||
return latent
|
return latent
|
||||||
|
|
||||||
class MiddleOut(nn.Module):
|
class MiddleOut(nn.Module):
|
||||||
@ -84,11 +85,15 @@ class MiddleOut(nn.Module):
|
|||||||
super(MiddleOut, self).__init__()
|
super(MiddleOut, self).__init__()
|
||||||
if residual:
|
if residual:
|
||||||
assert latent_size == region_latent_size
|
assert latent_size == region_latent_size
|
||||||
|
if num_peers == 0:
|
||||||
|
assert latent_size == region_latent_size
|
||||||
self.num_peers = num_peers
|
self.num_peers = num_peers
|
||||||
self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
|
self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
|
||||||
self.residual = residual
|
self.residual = residual
|
||||||
|
|
||||||
def forward(self, my_latent, peer_latents, peer_metrics):
|
def forward(self, my_latent, peer_latents, peer_metrics):
|
||||||
|
if self.num_peers == 0:
|
||||||
|
return my_latent
|
||||||
new_latents = []
|
new_latents = []
|
||||||
for p in range(peer_latents.shape[-2]):
|
for p in range(peer_latents.shape[-2]):
|
||||||
peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p]
|
peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p]
|
||||||
|
Loading…
Reference in New Issue
Block a user