diff --git a/utils.py b/utils.py index df0d395..c30f353 100644 --- a/utils.py +++ b/utils.py @@ -1,6 +1,6 @@ import matplotlib.pyplot as plt +from scipy.stats import norm import numpy as np -import wandb import os def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None): @@ -57,18 +57,24 @@ def visualize_prediction_grid(true_data, predicted_data, delta_data, num_points= plt.close() return file_path -def visualize_prediction(true_data, predicted_data, delta_data, num_points=None, epoch=None): +def visualize_prediction(true_data, predicted_data, delta_data, steps_data, num_points=None, epoch=None): """Visualize the combined plot of true data, predicted data, and deltas.""" if num_points: true_data = true_data[:num_points] predicted_data = predicted_data[:num_points] delta_data = delta_data[:num_points] + steps_data = steps_data[:num_points] plt.figure(figsize=(20, 10)) plt.plot(true_data, label='True Data') plt.plot(predicted_data, label='Predicted Data', color='orange') plt.plot(delta_data, label='Delta', color='red') + plt.plot(steps_data, label='Naive', color='darkred', linestyle='--') + + # Add horizontal line at y=0 + plt.axhline(y=0, color='gray', linestyle='--', linewidth=1) + plt.title('Combined Data') plt.xlabel('Sample') plt.ylabel('Amplitude') @@ -81,16 +87,30 @@ def visualize_prediction(true_data, predicted_data, delta_data, num_points=None, plt.close() return file_path + def plot_delta_distribution(deltas, epoch): - """Plot the distribution of deltas.""" + mu, std = 0, np.std(deltas) + + # Create histogram plt.figure(figsize=(10, 6)) - plt.hist(deltas, bins=100, density=True, alpha=0.6, color='g') + count, bins, ignored = plt.hist(deltas, bins=min(100, np.max(deltas) - np.min(deltas)), density=True, alpha=0.6, color='g') + + # Plot Gaussian curve + xmin, xmax = plt.xlim() + x = np.linspace(xmin, xmax, 100) + p = norm.pdf(x, mu, std) + plt.plot(x, p, 'k', linewidth=2) + + # Add title and labels plt.title(f'Delta Distribution at Epoch {epoch}') plt.xlabel('Delta') plt.ylabel('Density') plt.grid(True) + + # Save the plot tmp_dir = os.getenv('TMPDIR', '/tmp') file_path = os.path.join(tmp_dir, f'delta_distribution_epoch_{epoch}_{np.random.randint(1e6)}.png') plt.savefig(file_path) plt.close() - return file_path + + return file_path \ No newline at end of file