diff --git a/utils.py b/utils.py index f8f7f93..df0d395 100644 --- a/utils.py +++ b/utils.py @@ -14,7 +14,7 @@ def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None): plt.ylabel('Amplitude') plt.show() -def visualize_prediction(true_data, predicted_data, delta_data, num_points=None, epoch=None, name=''): +def visualize_prediction_grid(true_data, predicted_data, delta_data, num_points=None, epoch=None): """Visualize the true data, predicted data, deltas, and combined plot.""" if num_points: true_data = true_data[:num_points] @@ -55,8 +55,31 @@ def visualize_prediction(true_data, predicted_data, delta_data, num_points=None, file_path = os.path.join(tmp_dir, f'prediction_plot_{np.random.randint(1e6)}.png') plt.savefig(file_path) plt.close() - wandb.log({f"Prediction vs True Data {name}": wandb.Image(file_path)}, step=epoch) + return file_path +def visualize_prediction(true_data, predicted_data, delta_data, num_points=None, epoch=None): + """Visualize the combined plot of true data, predicted data, and deltas.""" + if num_points: + true_data = true_data[:num_points] + predicted_data = predicted_data[:num_points] + delta_data = delta_data[:num_points] + + plt.figure(figsize=(20, 10)) + + plt.plot(true_data, label='True Data') + plt.plot(predicted_data, label='Predicted Data', color='orange') + plt.plot(delta_data, label='Delta', color='red') + plt.title('Combined Data') + plt.xlabel('Sample') + plt.ylabel('Amplitude') + plt.legend() + + plt.tight_layout() + tmp_dir = os.getenv('TMPDIR', '/tmp') + file_path = os.path.join(tmp_dir, f'prediction_plot_{np.random.randint(1e6)}.png') + plt.savefig(file_path) + plt.close() + return file_path def plot_delta_distribution(deltas, epoch): """Plot the distribution of deltas."""