Fix eval code

parent cd3bdb0bf8
commit 74d4da5eba

main.py (140 changed lines)
@@ -6,7 +6,7 @@ import random, math
 from utils import visualize_prediction, plot_delta_distribution
 from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
 from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector, MiddleOut, Predictor
-from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
+from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder
 import wandb
 from pycallgraph2 import PyCallGraph
 from pycallgraph2.output import GraphvizOutput
@@ -58,6 +58,8 @@ class SpikeRunner(Slate_Runner):
             self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
+        elif latent_projector_type == 'fourier':
+            self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
         else:
             raise Exception('No such Latent Projector')

         self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
         self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device)
@@ -84,6 +86,10 @@ class SpikeRunner(Slate_Runner):
             self.encoder = ArithmeticEncoder()
         elif bitstream_type == 'bzip2':
             self.encoder = Bzip2Encoder()
+        elif bitstream_type == 'binomHuffman':
+            self.encoder = BinomialHuffmanEncoder()
         else:
             raise Exception('No such Encoder')

         # Optimizer
         self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
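Editor's note, not part of the commit: every encoder selected above is driven identically by the evaluation code further down — `build_model(...)`, then `encode(...)`, then `decode(compressed, length)`. A minimal sketch of the contract the new `BinomialHuffmanEncoder` has to satisfy; the class name and internals here are illustrative stand-ins, the real implementation lives in `bitstream.py` and is not shown in this diff:

```python
# Sketch of the encoder contract implied by the eval loop below.
# Illustrative only; not the repository's BinomialHuffmanEncoder.
import numpy as np

class EncoderSketch:
    def build_model(self, data: np.ndarray) -> None:
        """Fit symbol statistics (e.g. a binomial model of the quantized latents)."""
        self.mean = float(np.mean(data))  # placeholder statistic

    def encode(self, data: np.ndarray) -> bytes:
        """Compress the array into a bitstream."""
        raise NotImplementedError

    def decode(self, compressed: bytes, length: int) -> np.ndarray:
        """Reconstruct `length` values; eval checks np.allclose against the input."""
        raise NotImplementedError
```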
@@ -196,23 +202,16 @@ class SpikeRunner(Slate_Runner):
-                    self.save_models(epoch)
-            print(f'Evaluation complete for epoch {epoch + 1}')
-
-
             wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
             print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')

+            if (epoch + 1) % self.eval_freq == 0:
+                print(f'Starting evaluation for epoch {epoch + 1}')
+                test_loss = self.evaluate_model(epoch)
+                if test_loss < best_test_score:
+                    best_test_score = test_loss
+                    self.save_models(epoch)
+                print(f'Evaluation complete for epoch {epoch + 1}')


     def evaluate_model(self, epoch):
         print('Evaluating model...')
         device = self.device

         # Save the current mode of the models
         projector_mode = self.projector.training
         middle_out_mode = self.middle_out.training
         predictor_mode = self.predictor.training

         # Set models to evaluation mode
         self.projector.eval()
         self.middle_out.eval()
         self.predictor.eval()
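Editor's note, not part of the commit: the train/eval bookkeeping above (save each module's `.training` flag, switch to `.eval()`, restore at the end of `evaluate_model`) is often packaged as a context manager. A sketch under that assumption, relying only on the standard `Module.train(mode)` call:

```python
# Alternative packaging of the mode save/restore pattern in this commit.
# Module.train(mode: bool) restores either state in one call.
from contextlib import contextmanager

@contextmanager
def eval_mode(*modules):
    previous = [m.training for m in modules]  # remember each module's flag
    try:
        for m in modules:
            m.eval()
        yield
    finally:
        for m, was_training in zip(modules, previous):
            m.train(was_training)

# usage: with eval_mode(self.projector, self.middle_out, self.predictor): ...
```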
@@ -226,84 +225,91 @@ class SpikeRunner(Slate_Runner):
         total_sequences = 0

         with torch.no_grad():
-            for lead_idx in range(len(self.test_data[:8])):
-                lead_data = self.test_data[lead_idx]
-                true_data = []
-                predicted_data = []
-                delta_data = []
-                targets = []
-                min_length = min([len(seq) for seq in self.test_data])
+            min_length = min([len(seq) for seq in self.test_data])
+            errs = []
+            rels = []
+            derrs = []
+
+            for lead_idx in range(len(self.test_data)):
+                lead_data = self.test_data[lead_idx][:min_length]

                 indices = list(range(len(self.test_data)))
                 random.shuffle(indices)

                 # Initialize lists to store segments and peer metrics
                 stacked_segments = []
                 peer_metrics = []
+                targets = []
+                lasts = []

-                for i in range(0, len(lead_data) - self.input_size-1, self.input_size // 8):
+                for i in range(0, len(lead_data) - self.input_size - 1, self.input_size // 8):
                     lead_segment = lead_data[i:i + self.input_size]
                     inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)

                     # Collect peer segments and metrics
                     peer_segments = []
                     for peer_idx in self.sorted_peer_indices[lead_idx]:
-                        peer_segment = self.test_data[peer_idx][i:i + self.input_size][:min_length]
+                        peer_segment = self.test_data[peer_idx][:min_length][i:i + self.input_size]
                         peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
                     peer_metric = torch.tensor([self.topology_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device)
                     peer_metrics.append(peer_metric)

                     # Stack segments to form the batch
                     stacked_segment = torch.stack([inputs] + peer_segments).to(device)
                     stacked_segments.append(stacked_segment)
                     target = lead_data[i + self.input_size + 1]
                     targets.append(target)
+                    last = lead_data[i + self.input_size]
+                    lasts.append(last)

                 # Pass the batch through the projector
-                latents = self.projector(torch.stack(stacked_segments))
+                latents = self.projector(torch.stack(stacked_segments) / self.value_scale)

-                my_latents = latents[:, 0, :]
+                my_latent = latents[:, 0, :]
                 peer_latents = latents[:, 1:, :]

                 # Pass through MiddleOut
-                new_latents = self.middle_out(my_latents, peer_latents, torch.stack(peer_metrics))
+                region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
+                prediction = self.predictor(region_latent) * self.value_scale

-                # Predict using the predictor
-                predictions = self.predictor(new_latents)
+                tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
+                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
+                loss = self.criterion(prediction, tar)
+                err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
+                derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
+                rel = err / np.sum(tar.cpu().detach().numpy())
+                total_loss += loss.item()
+                derrs.append(derr / np.prod(tar.size()).item())
+                errs.append(err / np.prod(tar.size()).item())
+                rels.append(rel.item())

-                # Compute loss and store true and predicted data
-                for i, segment in enumerate(stacked_segments):
-                    for t in range(self.input_size):
-                        target = torch.tensor(targets[i])
-                        true_data.append(target.cpu().numpy())
-                        predicted_data.append(predictions[i].cpu().numpy())
-                        delta_data.append((target - predictions[i]).cpu().numpy())
-
-                        loss = self.criterion(predictions[i].cpu(), target)
-                        total_loss += loss.item()
-
-                # Append true and predicted data for this lead sequence
-                all_true.append(true_data)
-                all_predicted.append(predicted_data)
-                all_deltas.append(delta_data)
+                all_true.extend(tar.cpu().numpy())
+                all_predicted.extend(prediction.cpu().numpy())
+                all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())

                 if self.full_compression:
                     # Bitstream encoding
-                    self.encoder.build_model(my_latents.cpu().numpy())
-                    compressed_data = self.encoder.encode(my_latents.cpu().numpy())
-                    decompressed_data = self.encoder.decode(compressed_data, len(my_latents))
-                    compression_ratio = len(my_latents) / len(compressed_data)
+                    self.encoder.build_model(my_latent.cpu().numpy())
+                    compressed_data = self.encoder.encode(my_latent.cpu().numpy())
+                    decompressed_data = self.encoder.decode(compressed_data, len(my_latent))
+                    compression_ratio = len(my_latent) / len(compressed_data)
                     compression_ratios.append(compression_ratio)

                     # Check if decompressed data matches the original data
-                    if np.allclose(my_latents.cpu().numpy(), decompressed_data, atol=1e-5):
+                    if np.allclose(my_latent.cpu().numpy(), decompressed_data, atol=1e-5):
                         exact_matches += 1
                     total_sequences += 1

         avg_loss = total_loss / len(self.test_data)
-        print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
-        wandb.log({"evaluation_loss": avg_loss}, step=epoch)
+        tot_err = sum(errs) / len(errs)
+        tot_derr = sum(derrs) / len(derrs)
+        adv_delta = tot_derr / tot_err
+        approx_ratio = 1 / (sum(rels) / len(rels))

-        # Visualize delta distribution
-        delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch)
+        print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
+        wandb.log({"evaluation_loss": avg_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch)

         # Visualize predictions
+        visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
+        visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195, name='0.01s')
+        visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')

+        # Plot delta distribution
+        delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch)
         wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)

         if self.full_compression:
@@ -314,6 +320,22 @@ class SpikeRunner(Slate_Runner):
             wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
             wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)

+        # Restore the original mode of the models
+        if projector_mode:
+            self.projector.train()
+        else:
+            self.projector.eval()
+
+        if middle_out_mode:
+            self.middle_out.train()
+        else:
+            self.middle_out.eval()
+
+        if predictor_mode:
+            self.predictor.train()
+        else:
+            self.predictor.eval()
+
         print('Evaluation done for this epoch.')
         return avg_loss
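Editor's note, not part of the commit: the new metrics in `evaluate_model` reduce to a per-sample mean absolute error for the model (`err`), the same for a hold-last-value baseline (`derr`), their ratio (`adv_delta`, above 1 when the model beats the baseline), and the inverse mean relative error (`approx_ratio`). A self-contained restatement with dummy arrays; note that `np.sum(tar)` in the denominator of `rel` can be near zero or negative for zero-mean signals:

```python
# Standalone restatement of the metric arithmetic added to evaluate_model;
# prediction/target/last are dummy stand-ins for one evaluated batch.
import numpy as np

rng = np.random.default_rng(0)
target = rng.normal(loc=5.0, size=(64, 1))            # positive-mean dummy signal
prediction = target + rng.normal(scale=0.1, size=(64, 1))
last = target + rng.normal(scale=0.3, size=(64, 1))   # "repeat last sample" baseline

err = np.sum(np.abs(prediction - target))             # summed abs model error
derr = np.sum(np.abs(last - target))                  # summed abs baseline error
rel = err / np.sum(target)                            # relative error, as in the commit

err_per_sample = err / target.size
derr_per_sample = derr / target.size
adv_delta = derr_per_sample / err_per_sample          # >1: model beats the baseline
approx_ratio = 1 / rel                                # inverse mean relative error

print(adv_delta, approx_ratio)
```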
utils.py (24 changed lines)
@@ -14,39 +14,49 @@ def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None):
     plt.ylabel('Amplitude')
     plt.show()

-def visualize_prediction(true_data, predicted_data, delta_data, sample_rate, num_points=None, epoch=None):
-    """Visualize the true data, predicted data, and deltas."""
+def visualize_prediction(true_data, predicted_data, delta_data, num_points=None, epoch=None, name=''):
+    """Visualize the true data, predicted data, deltas, and combined plot."""
     if num_points:
         true_data = true_data[:num_points]
         predicted_data = predicted_data[:num_points]
         delta_data = delta_data[:num_points]

-    plt.figure(figsize=(15, 5))
+    plt.figure(figsize=(20, 5))

-    plt.subplot(3, 1, 1)
+    plt.subplot(2, 2, 1)
     plt.plot(true_data, label='True Data')
     plt.title('True Data')
     plt.xlabel('Sample')
     plt.ylabel('Amplitude')

-    plt.subplot(3, 1, 2)
+    plt.subplot(2, 2, 3)
     plt.plot(predicted_data, label='Predicted Data', color='orange')
     plt.title('Predicted Data')
     plt.xlabel('Sample')
     plt.ylabel('Amplitude')

-    plt.subplot(3, 1, 3)
+    plt.subplot(2, 2, 4)
     plt.plot(delta_data, label='Delta', color='red')
     plt.title('Delta')
     plt.xlabel('Sample')
     plt.ylabel('Amplitude')

+    plt.subplot(2, 2, 2)
+    plt.plot(true_data, label='True Data')
+    plt.plot(predicted_data, label='Predicted Data', color='orange')
+    plt.plot(delta_data, label='Delta', color='red')
+    plt.title('Combined Data')
+    plt.xlabel('Sample')
+    plt.ylabel('Amplitude')
+    plt.legend()

     plt.tight_layout()
     tmp_dir = os.getenv('TMPDIR', '/tmp')
     file_path = os.path.join(tmp_dir, f'prediction_plot_{np.random.randint(1e6)}.png')
     plt.savefig(file_path)
     plt.close()
-    wandb.log({"Prediction vs True Data": wandb.Image(file_path)}, step=epoch)
+    wandb.log({f"Prediction vs True Data {name}": wandb.Image(file_path)}, step=epoch)


 def plot_delta_distribution(deltas, epoch):
     """Plot the distribution of deltas."""
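Editor's note, not part of the commit: with the `sample_rate` parameter gone and `name` added, callers now tag each figure by time horizon, as main.py does three times above. A minimal smoke test of the new signature, assuming an offline wandb run so nothing is uploaded:

```python
# Hypothetical smoke test for the new visualize_prediction signature.
import numpy as np
import wandb
from utils import visualize_prediction

wandb.init(project="spikey-dev", mode="offline")  # offline run; figures stay local

t = np.linspace(0, 1, 2000)
true = np.sin(2 * np.pi * 5 * t)                  # stand-in waveform
pred = true + np.random.normal(0, 0.05, t.shape)  # stand-in prediction
delta = true - pred

# Mirrors one of the three calls added in main.py.
visualize_prediction(true, pred, delta, epoch=0, num_points=1953, name='0.1s')
```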