2024-05-24 22:01:59 +02:00
|
|
|
import matplotlib.pyplot as plt
|
2024-05-29 21:12:50 +02:00
|
|
|
from scipy.stats import norm
|
2024-05-24 22:01:59 +02:00
|
|
|
import numpy as np
|
|
|
|
import os
|
|
|
|
|
|
|
|
def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None):
    """Plot WAV samples against time (seconds) and show the figure.

    Args:
        sample_rate: Samples per second; converts sample indices to seconds.
        data: Sample values; must support ``len()`` and slicing and be
            plottable by ``plt.plot`` (e.g. a NumPy array).
        title: Figure title.
        num_points: If truthy, only the first ``num_points`` samples are plotted.
    """
    if num_points:
        data = data[:num_points]

    # Sample i occurs at i / sample_rate seconds. np.linspace(0, n / sr, num=n)
    # would include the end point, placing the last sample at n / sr and
    # stretching every interval to n / (sr * (n - 1)) instead of 1 / sr.
    times = np.arange(len(data)) / sample_rate

    plt.figure(figsize=(10, 4))
    plt.plot(times, data)
    plt.title(title)
    plt.xlabel('Time [s]')
    plt.ylabel('Amplitude')
    plt.show()
|
|
|
|
|
2024-05-28 12:53:00 +02:00
|
|
|
def visualize_prediction_grid(true_data, predicted_data, delta_data, num_points=None, epoch=None):
    """Render true/predicted/delta series on a 2x2 grid and save it as a PNG.

    Layout: top-left true data, bottom-left predicted data, bottom-right
    delta, top-right all three overlaid with a legend.

    Args:
        true_data: Ground-truth series (sliceable, passed to ``plt.plot``).
        predicted_data: Predicted series (same requirements).
        delta_data: Difference series (same requirements).
        num_points: If truthy, every series is truncated to its first
            ``num_points`` entries.
        epoch: Accepted for call-site compatibility; not used by this plot.

    Returns:
        Path of a newly created PNG in ``$TMPDIR`` (default ``/tmp``) whose
        name starts with ``prediction_plot_``.
    """
    import tempfile

    if num_points:
        true_data = true_data[:num_points]
        predicted_data = predicted_data[:num_points]
        delta_data = delta_data[:num_points]

    # (label/title, values, extra plot kwargs) — shared by the single panels
    # and the combined overlay so styling stays in sync.
    series = (
        ('True Data', true_data, {}),
        ('Predicted Data', predicted_data, {'color': 'orange'}),
        ('Delta', delta_data, {'color': 'red'}),
    )

    fig = plt.figure(figsize=(20, 5))
    try:
        # Individual panels: slot 1 (top-left), 3 (bottom-left), 4 (bottom-right).
        for slot, (name, values, style) in zip((1, 3, 4), series):
            plt.subplot(2, 2, slot)
            plt.plot(values, label=name, **style)
            plt.title(name)
            plt.xlabel('Sample')
            plt.ylabel('Amplitude')

        # Top-right slot overlays all three series.
        plt.subplot(2, 2, 2)
        for name, values, style in series:
            plt.plot(values, label=name, **style)
        plt.title('Combined Data')
        plt.xlabel('Sample')
        plt.ylabel('Amplitude')
        plt.legend()

        plt.tight_layout()

        tmp_dir = os.getenv('TMPDIR', '/tmp')
        # mkstemp atomically creates a uniquely named file; a random 6-digit
        # suffix could collide and silently overwrite another plot, and is
        # predictable in a shared temp directory.
        fd, file_path = tempfile.mkstemp(prefix='prediction_plot_', suffix='.png', dir=tmp_dir)
        os.close(fd)  # savefig reopens the path by name
        try:
            plt.savefig(file_path)
        except Exception:
            os.remove(file_path)  # don't leave an empty placeholder behind
            raise
    finally:
        # Release the figure even if drawing or saving failed, so repeated
        # calls (e.g. once per epoch) don't accumulate open figures.
        plt.close(fig)

    return file_path
|
|
|
|
|
2024-05-29 21:12:50 +02:00
|
|
|
def visualize_prediction(true_data, predicted_data, delta_data, steps_data, num_points=None, epoch=None):
    """Overlay true, predicted, delta and naive series on one axis; save as PNG.

    Args:
        true_data: Ground-truth series (sliceable, passed to ``plt.plot``).
        predicted_data: Predicted series, drawn in orange.
        delta_data: Difference series, drawn in red.
        steps_data: Series drawn as the dashed dark-red line labelled 'Naive'.
        num_points: If truthy, every series is truncated to its first
            ``num_points`` entries.
        epoch: Accepted for call-site compatibility; not used by this plot.

    Returns:
        Path of a newly created PNG in ``$TMPDIR`` (default ``/tmp``) whose
        name starts with ``prediction_plot_``.
    """
    import tempfile

    if num_points:
        true_data = true_data[:num_points]
        predicted_data = predicted_data[:num_points]
        delta_data = delta_data[:num_points]
        steps_data = steps_data[:num_points]

    fig = plt.figure(figsize=(20, 10))
    try:
        plt.plot(true_data, label='True Data')
        plt.plot(predicted_data, label='Predicted Data', color='orange')
        plt.plot(delta_data, label='Delta', color='red')
        plt.plot(steps_data, label='Naive', color='darkred', linestyle='--')

        # Horizontal reference line at y=0
        plt.axhline(y=0, color='gray', linestyle='--', linewidth=1)

        plt.title('Combined Data')
        plt.xlabel('Sample')
        plt.ylabel('Amplitude')
        plt.legend()
        plt.tight_layout()

        tmp_dir = os.getenv('TMPDIR', '/tmp')
        # mkstemp atomically creates a uniquely named file; a random 6-digit
        # suffix could collide and silently overwrite another plot, and is
        # predictable in a shared temp directory.
        fd, file_path = tempfile.mkstemp(prefix='prediction_plot_', suffix='.png', dir=tmp_dir)
        os.close(fd)  # savefig reopens the path by name
        try:
            plt.savefig(file_path)
        except Exception:
            os.remove(file_path)  # don't leave an empty placeholder behind
            raise
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

    return file_path
|
2024-05-24 22:01:59 +02:00
|
|
|
|
2024-05-29 21:12:50 +02:00
|
|
|
|
2024-05-24 22:01:59 +02:00
|
|
|
def plot_delta_distribution(deltas, epoch):
    """Save a density histogram of ``deltas`` with a reference Gaussian overlay.

    The overlay is a normal PDF centred at 0 whose scale is ``np.std(deltas)``
    (the spread around the sample mean, not around 0). It is omitted when that
    standard deviation is zero or not finite, since the PDF would be undefined.

    Bin count: for integer-typed deltas, one bin per unit of the value range,
    clamped to [1, 100]. For any other dtype, 100 bins.

    Args:
        deltas: Numeric array-like of differences to histogram.
        epoch: Shown in the plot title and embedded in the output filename.

    Returns:
        Path of a newly created PNG in ``$TMPDIR`` (default ``/tmp``) whose
        name starts with ``delta_distribution_epoch_{epoch}_``.

    Raises:
        ValueError: If ``deltas`` contains no values.
    """
    import tempfile

    values = np.asarray(deltas)
    if values.size == 0:
        raise ValueError("deltas must contain at least one value")

    # The reference curve is centred at 0 rather than at the sample mean.
    mu, std = 0, float(np.std(values))

    # Compute the range in float: max - min in a narrow integer dtype (e.g.
    # int16) can overflow to a negative number, and hist() needs a positive
    # integer bin count (a float or zero count raises).
    span = float(np.max(values)) - float(np.min(values))
    if np.issubdtype(values.dtype, np.integer):
        n_bins = int(min(100, max(1, span)))
    else:
        n_bins = 100

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.hist(deltas, bins=n_bins, density=True, alpha=0.6, color='g')

        # Gaussian overlay spanning the histogram's x-range
        if np.isfinite(std) and std > 0:
            xmin, xmax = plt.xlim()
            x = np.linspace(xmin, xmax, 100)
            plt.plot(x, norm.pdf(x, mu, std), 'k', linewidth=2)

        plt.title(f'Delta Distribution at Epoch {epoch}')
        plt.xlabel('Delta')
        plt.ylabel('Density')
        plt.grid(True)

        tmp_dir = os.getenv('TMPDIR', '/tmp')
        # mkstemp atomically creates a uniquely named file instead of relying
        # on a random 6-digit suffix that can collide and overwrite a plot.
        fd, file_path = tempfile.mkstemp(
            prefix=f'delta_distribution_epoch_{epoch}_', suffix='.png', dir=tmp_dir
        )
        os.close(fd)  # savefig reopens the path by name
        try:
            plt.savefig(file_path)
        except Exception:
            os.remove(file_path)  # don't leave an empty placeholder behind
            raise
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

    return file_path
|