Fix eval code

Dominik Moritz Roth 2024-05-27 10:28:51 +02:00
parent cd3bdb0bf8
commit 74d4da5eba
2 changed files with 98 additions and 66 deletions

main.py

@@ -6,7 +6,7 @@ import random, math
 from utils import visualize_prediction, plot_delta_distribution
 from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
 from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector,MiddleOut, Predictor
-from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
+from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder
 import wandb
 from pycallgraph2 import PyCallGraph
 from pycallgraph2.output import GraphvizOutput
@@ -58,6 +58,8 @@ class SpikeRunner(Slate_Runner):
             self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
         elif latent_projector_type == 'fourier':
             self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
+        else:
+            raise Exception('No such Latent Projector')
 
         self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
         self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device)
@@ -84,6 +86,10 @@ class SpikeRunner(Slate_Runner):
             self.encoder = ArithmeticEncoder()
         elif bitstream_type == 'bzip2':
             self.encoder = Bzip2Encoder()
+        elif bitstream_type == 'binomHuffman':
+            self.encoder = BinomialHuffmanEncoder()
+        else:
+            raise Exception('No such Encoder')
 
         # Optimizer
         self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
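Note: the encoder chosen above is only ever used through three calls in evaluate_model: build_model, encode, and decode. A minimal sketch of that interface, assuming the constructors take no arguments (bitstream.py is not part of this commit, so the real BinomialHuffmanEncoder internals are not shown; SketchEncoder is a hypothetical stand-in):

import bz2
import pickle

class SketchEncoder:
    """Hypothetical stand-in for IdentityEncoder / ArithmeticEncoder / Bzip2Encoder / BinomialHuffmanEncoder."""

    def build_model(self, data):
        # Fit whatever statistics the coder needs (e.g. symbol frequencies); a no-op in this sketch.
        pass

    def encode(self, data):
        # Serialize and compress the latent array.
        return bz2.compress(pickle.dumps(data))

    def decode(self, compressed_data, length):
        # Invert encode(); `length` mirrors the call site in evaluate_model and is unused in this sketch.
        return pickle.loads(bz2.decompress(compressed_data))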
@@ -196,23 +202,16 @@ class SpikeRunner(Slate_Runner):
                     self.save_models(epoch)
                 print(f'Evaluation complete for epoch {epoch + 1}')
 
-            wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
-            print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
-
-            if (epoch + 1) % self.eval_freq == 0:
-                print(f'Starting evaluation for epoch {epoch + 1}')
-                test_loss = self.evaluate_model(epoch)
-                if test_loss < best_test_score:
-                    best_test_score = test_loss
-                    self.save_models(epoch)
-                print(f'Evaluation complete for epoch {epoch + 1}')
-
     def evaluate_model(self, epoch):
         print('Evaluating model...')
         device = self.device
 
+        # Save the current mode of the models
+        projector_mode = self.projector.training
+        middle_out_mode = self.middle_out.training
+        predictor_mode = self.predictor.training
+
+        # Set models to evaluation mode
         self.projector.eval()
         self.middle_out.eval()
         self.predictor.eval()
@@ -226,84 +225,91 @@ class SpikeRunner(Slate_Runner):
         total_sequences = 0
 
         with torch.no_grad():
-            for lead_idx in range(len(self.test_data[:8])):
-                lead_data = self.test_data[lead_idx]
-
-                true_data = []
-                predicted_data = []
-                delta_data = []
-                targets = []
-
-                min_length = min([len(seq) for seq in self.test_data])
-
-                indices = list(range(len(self.test_data)))
-                random.shuffle(indices)
-
-                # Initialize lists to store segments and peer metrics
+            min_length = min([len(seq) for seq in self.test_data])
+
+            errs = []
+            rels = []
+            derrs = []
+
+            for lead_idx in range(len(self.test_data)):
+                lead_data = self.test_data[lead_idx][:min_length]
+
                 stacked_segments = []
                 peer_metrics = []
+                targets = []
+                lasts = []
 
-                for i in range(0, len(lead_data) - self.input_size-1, self.input_size // 8):
+                for i in range(0, len(lead_data) - self.input_size - 1, self.input_size // 8):
                     lead_segment = lead_data[i:i + self.input_size]
                     inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
 
-                    # Collect peer segments and metrics
                     peer_segments = []
                     for peer_idx in self.sorted_peer_indices[lead_idx]:
-                        peer_segment = self.test_data[peer_idx][i:i + self.input_size][:min_length]
+                        peer_segment = self.test_data[peer_idx][:min_length][i:i + self.input_size]
                         peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
                     peer_metric = torch.tensor([self.topology_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device)
                     peer_metrics.append(peer_metric)
 
-                    # Stack segments to form the batch
                     stacked_segment = torch.stack([inputs] + peer_segments).to(device)
                     stacked_segments.append(stacked_segment)
 
                     target = lead_data[i + self.input_size + 1]
                     targets.append(target)
+                    last = lead_data[i + self.input_size]
+                    lasts.append(last)
 
-                # Pass the batch through the projector
-                latents = self.projector(torch.stack(stacked_segments))
-                my_latents = latents[:, 0, :]
+                latents = self.projector(torch.stack(stacked_segments) / self.value_scale)
+                my_latent = latents[:, 0, :]
                 peer_latents = latents[:, 1:, :]
 
-                # Pass through MiddleOut
-                new_latents = self.middle_out(my_latents, peer_latents, torch.stack(peer_metrics))
-
-                # Predict using the predictor
-                predictions = self.predictor(new_latents)
-
-                # Compute loss and store true and predicted data
-                for i, segment in enumerate(stacked_segments):
-                    for t in range(self.input_size):
-                        target = torch.tensor(targets[i])
-                        true_data.append(target.cpu().numpy())
-                        predicted_data.append(predictions[i].cpu().numpy())
-                        delta_data.append((target - predictions[i]).cpu().numpy())
-                        loss = self.criterion(predictions[i].cpu(), target)
-                        total_loss += loss.item()
-
-                # Append true and predicted data for this lead sequence
-                all_true.append(true_data)
-                all_predicted.append(predicted_data)
-                all_deltas.append(delta_data)
+                region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
+                prediction = self.predictor(region_latent) * self.value_scale
+
+                tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
+                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
+
+                loss = self.criterion(prediction, tar)
+                err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
+                derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
+                rel = err / np.sum(tar.cpu().detach().numpy())
+
+                total_loss += loss.item()
+                derrs.append(derr / np.prod(tar.size()).item())
+                errs.append(err / np.prod(tar.size()).item())
+                rels.append(rel.item())
+
+                all_true.extend(tar.cpu().numpy())
+                all_predicted.extend(prediction.cpu().numpy())
+                all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
 
                 if self.full_compression:
-                    # Bitstream encoding
-                    self.encoder.build_model(my_latents.cpu().numpy())
-                    compressed_data = self.encoder.encode(my_latents.cpu().numpy())
-                    decompressed_data = self.encoder.decode(compressed_data, len(my_latents))
-                    compression_ratio = len(my_latents) / len(compressed_data)
+                    self.encoder.build_model(my_latent.cpu().numpy())
+                    compressed_data = self.encoder.encode(my_latent.cpu().numpy())
+                    decompressed_data = self.encoder.decode(compressed_data, len(my_latent))
+                    compression_ratio = len(my_latent) / len(compressed_data)
                     compression_ratios.append(compression_ratio)
 
-                    # Check if decompressed data matches the original data
-                    if np.allclose(my_latents.cpu().numpy(), decompressed_data, atol=1e-5):
+                    if np.allclose(my_latent.cpu().numpy(), decompressed_data, atol=1e-5):
                         exact_matches += 1
                     total_sequences += 1
 
         avg_loss = total_loss / len(self.test_data)
-        print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
-        wandb.log({"evaluation_loss": avg_loss}, step=epoch)
+        tot_err = sum(errs) / len(errs)
+        tot_derr = sum(derrs) / len(derrs)
+        adv_delta = tot_derr / tot_err
+        approx_ratio = 1 / (sum(rels) / len(rels))
 
-        # Visualize delta distribution
-        delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch)
+        print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
+        wandb.log({"evaluation_loss": avg_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch)
+
+        # Visualize predictions
+        visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
+        visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195, name='0.01s')
+        visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')
+
+        # Plot delta distribution
+        delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch)
         wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
 
         if self.full_compression:
@@ -314,6 +320,22 @@ class SpikeRunner(Slate_Runner):
             wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
             wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
 
+        # Restore the original mode of the models
+        if projector_mode:
+            self.projector.train()
+        else:
+            self.projector.eval()
+
+        if middle_out_mode:
+            self.middle_out.train()
+        else:
+            self.middle_out.eval()
+
+        if predictor_mode:
+            self.predictor.train()
+        else:
+            self.predictor.eval()
+
         print('Evaluation done for this epoch.')
         return avg_loss
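Note: the new metric block can be read as follows (inferred from the code above, not stated in the commit): errs holds the per-sample mean absolute prediction error for each lead, derrs the same quantity for a naive baseline that simply repeats the last known sample, adv_delta = tot_derr / tot_err is the factor by which the model beats that baseline, and approx_ratio is the reciprocal of the mean relative error rel = sum(|prediction - target|) / sum(target). A toy recomputation of the same arithmetic with hypothetical numbers:

import numpy as np

# Hypothetical values for one lead: model prediction, true target, and last known sample.
prediction = np.array([1.5, 2.5, 3.5])
tar = np.array([1.0, 2.0, 3.0])
las = np.array([0.0, 1.0, 2.0])          # the "repeat the previous sample" baseline

err = np.sum(np.abs(prediction - tar))   # summed absolute error of the model -> 1.5
derr = np.sum(np.abs(las - tar))         # summed absolute error of the baseline -> 3.0
rel = err / np.sum(tar)                  # error relative to the summed target -> 0.25

adv_delta = derr / err                   # 2.0: the model halves the baseline error
approx_ratio = 1 / rel                   # 4.0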

utils.py

@@ -14,39 +14,49 @@ def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None):
     plt.ylabel('Amplitude')
     plt.show()
 
-def visualize_prediction(true_data, predicted_data, delta_data, sample_rate, num_points=None, epoch=None):
-    """Visualize the true data, predicted data, and deltas."""
+def visualize_prediction(true_data, predicted_data, delta_data, num_points=None, epoch=None, name=''):
+    """Visualize the true data, predicted data, deltas, and combined plot."""
     if num_points:
         true_data = true_data[:num_points]
         predicted_data = predicted_data[:num_points]
         delta_data = delta_data[:num_points]
 
-    plt.figure(figsize=(15, 5))
-    plt.subplot(3, 1, 1)
+    plt.figure(figsize=(20, 5))
+    plt.subplot(2, 2, 1)
     plt.plot(true_data, label='True Data')
     plt.title('True Data')
     plt.xlabel('Sample')
     plt.ylabel('Amplitude')
 
-    plt.subplot(3, 1, 2)
+    plt.subplot(2, 2, 3)
     plt.plot(predicted_data, label='Predicted Data', color='orange')
     plt.title('Predicted Data')
     plt.xlabel('Sample')
     plt.ylabel('Amplitude')
 
-    plt.subplot(3, 1, 3)
+    plt.subplot(2, 2, 4)
     plt.plot(delta_data, label='Delta', color='red')
     plt.title('Delta')
     plt.xlabel('Sample')
     plt.ylabel('Amplitude')
 
+    plt.subplot(2, 2, 2)
+    plt.plot(true_data, label='True Data')
+    plt.plot(predicted_data, label='Predicted Data', color='orange')
+    plt.plot(delta_data, label='Delta', color='red')
+    plt.title('Combined Data')
+    plt.xlabel('Sample')
+    plt.ylabel('Amplitude')
+    plt.legend()
+
     plt.tight_layout()
     tmp_dir = os.getenv('TMPDIR', '/tmp')
     file_path = os.path.join(tmp_dir, f'prediction_plot_{np.random.randint(1e6)}.png')
     plt.savefig(file_path)
     plt.close()
 
-    wandb.log({"Prediction vs True Data": wandb.Image(file_path)}, step=epoch)
+    wandb.log({f"Prediction vs True Data {name}": wandb.Image(file_path)}, step=epoch)
 
 def plot_delta_distribution(deltas, epoch):
     """Plot the distribution of deltas."""