Nicer graphs

This commit is contained in:
Dominik Moritz Roth 2024-05-29 21:12:50 +02:00
parent 404b59c8ba
commit 032099f432

View File

@ -1,6 +1,6 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.stats import norm
import numpy as np import numpy as np
import wandb
import os import os
def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None): def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None):
@ -57,18 +57,24 @@ def visualize_prediction_grid(true_data, predicted_data, delta_data, num_points=
plt.close() plt.close()
return file_path return file_path
def visualize_prediction(true_data, predicted_data, delta_data, num_points=None, epoch=None): def visualize_prediction(true_data, predicted_data, delta_data, steps_data, num_points=None, epoch=None):
"""Visualize the combined plot of true data, predicted data, and deltas.""" """Visualize the combined plot of true data, predicted data, and deltas."""
if num_points: if num_points:
true_data = true_data[:num_points] true_data = true_data[:num_points]
predicted_data = predicted_data[:num_points] predicted_data = predicted_data[:num_points]
delta_data = delta_data[:num_points] delta_data = delta_data[:num_points]
steps_data = steps_data[:num_points]
plt.figure(figsize=(20, 10)) plt.figure(figsize=(20, 10))
plt.plot(true_data, label='True Data') plt.plot(true_data, label='True Data')
plt.plot(predicted_data, label='Predicted Data', color='orange') plt.plot(predicted_data, label='Predicted Data', color='orange')
plt.plot(delta_data, label='Delta', color='red') plt.plot(delta_data, label='Delta', color='red')
plt.plot(steps_data, label='Naive', color='darkred', linestyle='--')
# Add horizontal line at y=0
plt.axhline(y=0, color='gray', linestyle='--', linewidth=1)
plt.title('Combined Data') plt.title('Combined Data')
plt.xlabel('Sample') plt.xlabel('Sample')
plt.ylabel('Amplitude') plt.ylabel('Amplitude')
@ -81,16 +87,30 @@ def visualize_prediction(true_data, predicted_data, delta_data, num_points=None,
plt.close() plt.close()
return file_path return file_path
def plot_delta_distribution(deltas, epoch): def plot_delta_distribution(deltas, epoch):
"""Plot the distribution of deltas.""" mu, std = 0, np.std(deltas)
# Create histogram
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
plt.hist(deltas, bins=100, density=True, alpha=0.6, color='g') count, bins, ignored = plt.hist(deltas, bins=min(100, np.max(deltas) - np.min(deltas)), density=True, alpha=0.6, color='g')
# Plot Gaussian curve
xmin, xmax = plt.xlim()
x = np.linspace(xmin, xmax, 100)
p = norm.pdf(x, mu, std)
plt.plot(x, p, 'k', linewidth=2)
# Add title and labels
plt.title(f'Delta Distribution at Epoch {epoch}') plt.title(f'Delta Distribution at Epoch {epoch}')
plt.xlabel('Delta') plt.xlabel('Delta')
plt.ylabel('Density') plt.ylabel('Density')
plt.grid(True) plt.grid(True)
# Save the plot
tmp_dir = os.getenv('TMPDIR', '/tmp') tmp_dir = os.getenv('TMPDIR', '/tmp')
file_path = os.path.join(tmp_dir, f'delta_distribution_epoch_{epoch}_{np.random.randint(1e6)}.png') file_path = os.path.join(tmp_dir, f'delta_distribution_epoch_{epoch}_{np.random.randint(1e6)}.png')
plt.savefig(file_path) plt.savefig(file_path)
plt.close() plt.close()
return file_path return file_path