Fix eval code

Dominik Moritz Roth 2024-05-27 10:28:51 +02:00
parent cd3bdb0bf8
commit 74d4da5eba
2 changed files with 98 additions and 66 deletions

main.py

@@ -6,7 +6,7 @@ import random, math
from utils import visualize_prediction, plot_delta_distribution
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector, MiddleOut, Predictor
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder
import wandb
from pycallgraph2 import PyCallGraph
from pycallgraph2.output import GraphvizOutput
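BinomialHuffmanEncoder lives in bitstream.py, which this commit does not touch, so only its call surface is visible here. From how the evaluation loop below uses the encoders (build_model, encode, then decode with an expected symbol count), a minimal sketch of that interface could look as follows; the body is an assumption for illustration, not the project's implementation:

```python
import heapq
from collections import Counter

import numpy as np


class BinomialHuffmanEncoder:
    """Illustrative stand-in; only the method names and call pattern are
    taken from how main.py uses the encoder below."""

    def build_model(self, samples):
        # Quantize to integers and count symbol frequencies. A true
        # "binomial" coder would take the frequencies from a fitted
        # binomial distribution rather than the empirical counts.
        symbols = np.round(np.asarray(samples).ravel()).astype(int)
        self.codebook = self._huffman(Counter(symbols.tolist()))
        self.inverse = {code: sym for sym, code in self.codebook.items()}

    def encode(self, samples):
        symbols = np.round(np.asarray(samples).ravel()).astype(int)
        return ''.join(self.codebook[s] for s in symbols)

    def decode(self, bits, num_symbols):
        out, code = [], ''
        for bit in bits:
            code += bit
            if code in self.inverse:  # prefix code: first match wins
                out.append(self.inverse[code])
                code = ''
                if len(out) == num_symbols:
                    break
        return np.array(out, dtype=float)

    @staticmethod
    def _huffman(freqs):
        # Standard Huffman construction over a min-heap; each heap entry
        # carries a partial codebook that grows one bit per merge.
        if len(freqs) == 1:
            return {next(iter(freqs)): '0'}
        heap = [(w, i, {sym: ''}) for i, (sym, w) in enumerate(freqs.items())]
        heapq.heapify(heap)
        tie = len(heap)
        while len(heap) > 1:
            w1, _, c1 = heapq.heappop(heap)
            w2, _, c2 = heapq.heappop(heap)
            merged = {s: '0' + c for s, c in c1.items()}
            merged.update({s: '1' + c for s, c in c2.items()})
            heapq.heappush(heap, (w1 + w2, tie, merged))
            tie += 1
        return heap[0][2]
```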
@@ -58,6 +58,8 @@ class SpikeRunner(Slate_Runner):
self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
elif latent_projector_type == 'fourier':
self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
else:
raise Exception('No such Latent Projector')
self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device)
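Both this projector selection and the encoder selection below dispatch on a config string through an if/elif chain that ends in a bare raise Exception. A dict-based registry is a compact alternative that also names the offending value; a sketch, where the 'fc' key is assumed from the LatentFCProjector import at the top (that branch sits outside this hunk):

```python
# Hypothetical registry; the classes are the ones imported from models.
PROJECTORS = {
    'fc': LatentFCProjector,
    'rnn': LatentRNNProjector,
    'fourier': LatentFourierProjector,
}

def make_projector(kind, latent_size, input_size, **kwargs):
    try:
        cls = PROJECTORS[kind]
    except KeyError:
        raise ValueError(f'No such Latent Projector: {kind!r}') from None
    return cls(latent_size=latent_size, input_size=input_size, **kwargs)
```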
@@ -84,6 +86,10 @@ class SpikeRunner(Slate_Runner):
self.encoder = ArithmeticEncoder()
elif bitstream_type == 'bzip2':
self.encoder = Bzip2Encoder()
elif bitstream_type == 'binomHuffman':
self.encoder = BinomialHuffmanEncoder()
else:
raise Exception('No such Encoder')
# Optimizer
self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
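The single Adam instance above updates projector, MiddleOut and predictor jointly by concatenating their parameter lists. itertools.chain expresses the same thing without materializing the intermediate lists; a self-contained sketch with stand-in modules:

```python
import itertools

import torch
import torch.nn as nn

# Stand-ins for the projector, middle_out and predictor modules.
modules = [nn.Linear(8, 4), nn.Linear(4, 4), nn.Linear(4, 1)]

# One optimizer over the chained parameters of all three modules.
optimizer = torch.optim.Adam(
    itertools.chain.from_iterable(m.parameters() for m in modules),
    lr=1e-3,
)
```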
@@ -196,23 +202,16 @@ class SpikeRunner(Slate_Runner):
self.save_models(epoch)
print(f'Evaluation complete for epoch {epoch + 1}')
wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
if (epoch + 1) % self.eval_freq == 0:
print(f'Starting evaluation for epoch {epoch + 1}')
test_loss = self.evaluate_model(epoch)
if test_loss < best_test_score:
best_test_score = test_loss
self.save_models(epoch)
print(f'Evaluation complete for epoch {epoch + 1}')
def evaluate_model(self, epoch):
print('Evaluating model...')
device = self.device
# Save the current mode of the models
projector_mode = self.projector.training
middle_out_mode = self.middle_out.training
predictor_mode = self.predictor.training
# Set models to evaluation mode
self.projector.eval()
self.middle_out.eval()
self.predictor.eval()
@@ -226,84 +225,91 @@ class SpikeRunner(Slate_Runner):
total_sequences = 0
with torch.no_grad():
for lead_idx in range(len(self.test_data[:8])):
lead_data = self.test_data[lead_idx]
true_data = []
predicted_data = []
delta_data = []
targets = []
min_length = min([len(seq) for seq in self.test_data])
errs = []
rels = []
derrs = []
for lead_idx in range(len(self.test_data)):
lead_data = self.test_data[lead_idx][:min_length]
indices = list(range(len(self.test_data)))
random.shuffle(indices)
# Initialize lists to store segments and peer metrics
stacked_segments = []
peer_metrics = []
targets = []
lasts = []
for i in range(0, len(lead_data) - self.input_size - 1, self.input_size // 8):
lead_segment = lead_data[i:i + self.input_size]
inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
# Collect peer segments and metrics
peer_segments = []
for peer_idx in self.sorted_peer_indices[lead_idx]:
peer_segment = self.test_data[peer_idx][i:i + self.input_size][:min_length]
peer_segment = self.test_data[peer_idx][:min_length][i:i + self.input_size]
peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
peer_metric = torch.tensor([self.topology_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device)
peer_metrics.append(peer_metric)
# Stack segments to form the batch
stacked_segment = torch.stack([inputs] + peer_segments).to(device)
stacked_segments.append(stacked_segment)
target = lead_data[i + self.input_size + 1]
targets.append(target)
last = lead_data[i + self.input_size]
lasts.append(last)
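The loop above slices each lead into windows of input_size samples with a stride of input_size // 8, so consecutive windows overlap by 7/8; lead_data[i + input_size + 1] is kept as the prediction target and lead_data[i + input_size] as the last-value baseline. The same indexing in isolation, with illustrative sizes:

```python
import numpy as np

def make_windows(lead, input_size, stride):
    """Overlapping windows plus next-step targets and last-value baselines."""
    segments, targets, lasts = [], [], []
    for i in range(0, len(lead) - input_size - 1, stride):
        segments.append(lead[i:i + input_size])
        targets.append(lead[i + input_size + 1])  # sample to predict
        lasts.append(lead[i + input_size])        # naive "repeat last" baseline
    return np.array(segments), np.array(targets), np.array(lasts)

lead = np.arange(100.0)
segments, targets, lasts = make_windows(lead, input_size=32, stride=32 // 8)
print(segments.shape)  # (17, 32)
```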
# Pass the batch through the projector
latents = self.projector(torch.stack(stacked_segments))
latents = self.projector(torch.stack(stacked_segments) / self.value_scale)
my_latents = latents[:, 0, :]
my_latent = latents[:, 0, :]
peer_latents = latents[:, 1:, :]
# Pass through MiddleOut
new_latents = self.middle_out(my_latents, peer_latents, torch.stack(peer_metrics))
region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
prediction = self.predictor(region_latent) * self.value_scale
# Predict using the predictor
predictions = self.predictor(new_latents)
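The projector/predictor step above carries one of the commit's actual fixes: inputs are now divided by self.value_scale before projection and the prediction is multiplied back up, so the network sees roughly unit-scale values while the loss and error terms below stay in raw amplitude units. The round trip in isolation (the scale value here is an assumption for illustration):

```python
import torch

value_scale = 2 ** 15  # plausible for 16-bit WAV amplitudes (assumption)

raw = torch.tensor([12000.0, -8000.0])
normalized = raw / value_scale          # what the projector receives
net_output = torch.tensor([0.3661, -0.2441])
prediction = net_output * value_scale   # back in raw amplitude units
```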
tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
loss = self.criterion(prediction, tar)
err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
rel = err / np.sum(tar.cpu().detach().numpy())
total_loss += loss.item()
derrs.append(derr / np.prod(tar.size()).item())
errs.append(err / np.prod(tar.size()).item())
rels.append(rel.item())
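Here err is the absolute error of the predictor, derr the same for the last-value baseline, and rel the summed error relative to the summed targets; after the loop these average into tot_err, tot_derr, adv_delta = tot_derr / tot_err (how many times better than the baseline the predictor is) and approx_ratio = 1 / mean(rel). The arithmetic with toy numbers:

```python
import numpy as np

prediction = np.array([1.1, 2.0, 2.9])  # model output
target = np.array([1.0, 2.0, 3.0])      # ground truth
last = np.array([0.9, 1.8, 2.7])        # last-value baseline

err = np.abs(prediction - target).mean()   # 0.0667: model MAE
derr = np.abs(last - target).mean()        # 0.2:    baseline MAE
adv_delta = derr / err                     # 3.0: model is 3x better
rel = np.abs(prediction - target).sum() / target.sum()  # 0.0333
approx_ratio = 1 / rel                     # 30.0
```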
# Compute loss and store true and predicted data
for i, segment in enumerate(stacked_segments):
for t in range(self.input_size):
target = torch.tensor(targets[i])
true_data.append(target.cpu().numpy())
predicted_data.append(predictions[i].cpu().numpy())
delta_data.append((target - predictions[i]).cpu().numpy())
loss = self.criterion(predictions[i].cpu(), target)
total_loss += loss.item()
# Append true and predicted data for this lead sequence
all_true.append(true_data)
all_predicted.append(predicted_data)
all_deltas.append(delta_data)
all_true.extend(tar.cpu().numpy())
all_predicted.extend(prediction.cpu().numpy())
all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
if self.full_compression:
# Bitstream encoding
self.encoder.build_model(my_latents.cpu().numpy())
compressed_data = self.encoder.encode(my_latents.cpu().numpy())
decompressed_data = self.encoder.decode(compressed_data, len(my_latents))
compression_ratio = len(my_latents) / len(compressed_data)
self.encoder.build_model(my_latent.cpu().numpy())
compressed_data = self.encoder.encode(my_latent.cpu().numpy())
decompressed_data = self.encoder.decode(compressed_data, len(my_latent))
compression_ratio = len(my_latent) / len(compressed_data)
compression_ratios.append(compression_ratio)
# Check if decompressed data matches the original data
if np.allclose(my_latents.cpu().numpy(), decompressed_data, atol=1e-5):
if np.allclose(my_latent.cpu().numpy(), decompressed_data, atol=1e-5):
exact_matches += 1
total_sequences += 1
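The compression branch above measures a ratio and verifies that the round trip reproduces the latents to within 1e-5. The same bookkeeping in a self-contained sketch, with bz2 standing in for the project's encoders (an assumption for illustration):

```python
import bz2

import numpy as np

latent = np.random.randn(64).astype(np.float32)
raw = latent.tobytes()
compressed = bz2.compress(raw)
restored = np.frombuffer(bz2.decompress(compressed), dtype=np.float32)

compression_ratio = len(raw) / len(compressed)      # bytes in / bytes out
exact = np.allclose(latent, restored, atol=1e-5)    # lossless here, so True
print(compression_ratio, exact)
```

Note one unit mismatch in the code above: len(my_latent) counts latent elements while len(compressed_data) measures the encoder's output, so what the ratio means depends on the encoder's output type.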
avg_loss = total_loss / len(self.test_data)
print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
wandb.log({"evaluation_loss": avg_loss}, step=epoch)
tot_err = sum(errs) / len(errs)
tot_derr = sum(derrs) / len(derrs)
adv_delta = tot_derr / tot_err
approx_ratio = 1 / (sum(rels) / len(rels))
# Visualize delta distribution
delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch)
print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
wandb.log({"evaluation_loss": avg_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch)
# Visualize predictions
visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195, name='0.01s')
visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')
# Plot delta distribution
delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch)
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
if self.full_compression:
@@ -314,6 +320,22 @@ class SpikeRunner(Slate_Runner):
wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
# Restore the original mode of the models
if projector_mode:
self.projector.train()
else:
self.projector.eval()
if middle_out_mode:
self.middle_out.train()
else:
self.middle_out.eval()
if predictor_mode:
self.predictor.train()
else:
self.predictor.eval()
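This commit's restore logic saves each module's .training flag before evaluation and puts it back afterwards, instead of blindly calling .train(). The same save/restore dance can be packaged once as a context manager; a minimal sketch:

```python
from contextlib import contextmanager

@contextmanager
def evaluating(*modules):
    """Temporarily put modules in eval mode, restoring their previous state."""
    was_training = [m.training for m in modules]
    try:
        for m in modules:
            m.eval()
        yield
    finally:
        for m, training in zip(modules, was_training):
            m.train(training)  # train(False) is equivalent to eval()
```

With this, `with evaluating(self.projector, self.middle_out, self.predictor): ...` would replace both the flag-saving at the top of evaluate_model and the restore block above, and it restores modes even if evaluation raises.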
print('Evaluation done for this epoch.')
return avg_loss

utils.py

@@ -14,39 +14,49 @@ def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None):
plt.ylabel('Amplitude')
plt.show()
def visualize_prediction(true_data, predicted_data, delta_data, sample_rate, num_points=None, epoch=None):
"""Visualize the true data, predicted data, and deltas."""
def visualize_prediction(true_data, predicted_data, delta_data, num_points=None, epoch=None, name=''):
"""Visualize the true data, predicted data, deltas, and combined plot."""
if num_points:
true_data = true_data[:num_points]
predicted_data = predicted_data[:num_points]
delta_data = delta_data[:num_points]
plt.figure(figsize=(15, 5))
plt.figure(figsize=(20, 5))
plt.subplot(3, 1, 1)
plt.subplot(2, 2, 1)
plt.plot(true_data, label='True Data')
plt.title('True Data')
plt.xlabel('Sample')
plt.ylabel('Amplitude')
plt.subplot(3, 1, 2)
plt.subplot(2, 2, 3)
plt.plot(predicted_data, label='Predicted Data', color='orange')
plt.title('Predicted Data')
plt.xlabel('Sample')
plt.ylabel('Amplitude')
plt.subplot(3, 1, 3)
plt.subplot(2, 2, 4)
plt.plot(delta_data, label='Delta', color='red')
plt.title('Delta')
plt.xlabel('Sample')
plt.ylabel('Amplitude')
plt.subplot(2, 2, 2)
plt.plot(true_data, label='True Data')
plt.plot(predicted_data, label='Predicted Data', color='orange')
plt.plot(delta_data, label='Delta', color='red')
plt.title('Combined Data')
plt.xlabel('Sample')
plt.ylabel('Amplitude')
plt.legend()
plt.tight_layout()
tmp_dir = os.getenv('TMPDIR', '/tmp')
file_path = os.path.join(tmp_dir, f'prediction_plot_{np.random.randint(1e6)}.png')
plt.savefig(file_path)
plt.close()
wandb.log({"Prediction vs True Data": wandb.Image(file_path)}, step=epoch)
wandb.log({f"Prediction vs True Data {name}": wandb.Image(file_path)}, step=epoch)
def plot_delta_distribution(deltas, epoch):
"""Plot the distribution of deltas."""