Spikey/analysis.ipynb

615 lines
26 KiB
Plaintext
Raw Permalink Normal View History

2024-05-25 17:31:08 +02:00
{
"cells": [
{
"cell_type": "code",
2024-05-25 17:33:19 +02:00
"execution_count": null,
"id": "bc4aa0ae",
2024-05-25 17:31:08 +02:00
"metadata": {
"scrolled": false
},
2024-05-25 17:33:19 +02:00
"outputs": [],
2024-05-25 17:31:08 +02:00
"source": [
"import os\n",
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"import plotly.express as px\n",
"from scipy.io import wavfile\n",
"from data_processing import load_all_wavs\n",
"import random\n",
"\n",
"# Constants\n",
"SAMPLE_RATE = 19531\n",
"LEAD_LIMIT = 256\n",
"TIME_LIMIT_SEC = 0.1\n",
"DATA_DIR = \"data\"\n",
"\n",
"# Utility functions for visualizations\n",
"\n",
"def load_and_trim_data(data_dir, sample_rate, lead_limit, time_limit_sec):\n",
" \"\"\"Load and trim the data to specified lead and time limits.\"\"\"\n",
" all_data = load_all_wavs(data_dir)\n",
" trimmed_data = []\n",
" max_samples = int(sample_rate * time_limit_sec)\n",
" for lead in all_data[:lead_limit]:\n",
" trimmed_data.append(lead[:max_samples])\n",
" return np.array(trimmed_data)\n",
"\n",
"def plot_individual_leads(data, sample_rate):\n",
" \"\"\"Plot individual leads using Plotly for interactive zoom.\"\"\"\n",
" fig = go.Figure()\n",
" for i, lead in enumerate(data):\n",
" fig.add_trace(go.Scatter(\n",
" x=np.linspace(0, len(lead) / sample_rate, num=len(lead)),\n",
" y=lead,\n",
" mode='lines',\n",
" name=f'Lead {i+1}'\n",
" ))\n",
" fig.update_layout(\n",
" title='Individual Leads',\n",
" xaxis_title='Time [s]',\n",
" yaxis_title='Amplitude'\n",
" )\n",
" return fig\n",
"\n",
"def plot_lead_correlations(data):\n",
" \"\"\"Plot correlation between leads using Plotly.\"\"\"\n",
" correlations = np.corrcoef(data)\n",
" fig = px.imshow(correlations,\n",
" labels=dict(color=\"Correlation\"),\n",
" x=[f'Lead {i+1}' for i in range(len(data))],\n",
" y=[f'Lead {i+1}' for i in range(len(data))],\n",
" color_continuous_scale='RdBu_r',\n",
" zmin=-1, zmax=1)\n",
" fig.update_layout(\n",
" title='Correlation Between Leads'\n",
" )\n",
" return fig\n",
"\n",
"def plot_highly_correlated_leads(data, threshold=0.75):\n",
" \"\"\"Plot leads that have high correlation with each other using Plotly.\"\"\"\n",
" correlations = np.corrcoef(data)\n",
" high_corr_pairs = np.argwhere(correlations > threshold)\n",
" grouped_pairs = {}\n",
" \n",
" for (i, j) in high_corr_pairs:\n",
" if i >= j:\n",
" continue\n",
" if i not in grouped_pairs:\n",
" grouped_pairs[i] = []\n",
" grouped_pairs[i].append(j)\n",
"\n",
" figs = []\n",
" for i, group in grouped_pairs.items():\n",
" fig = go.Figure()\n",
" fig.add_trace(go.Scatter(\n",
" x=np.arange(len(data[i])),\n",
" y=data[i],\n",
" mode='lines',\n",
" name=f'Lead {i+1}'\n",
" ))\n",
" for j in group:\n",
" fig.add_trace(go.Scatter(\n",
" x=np.arange(len(data[j])),\n",
" y=data[j],\n",
" mode='lines',\n",
" name=f'Lead {j+1}',\n",
" line=dict(dash='dash')\n",
" ))\n",
" fig.update_layout(\n",
" title=f'Highly Correlated Leads Group {i+1}',\n",
" xaxis_title='Sample',\n",
" yaxis_title='Amplitude'\n",
" )\n",
" figs.append(fig)\n",
" return figs\n",
"\n",
"def plot_top_correlated_pairs(data, top_n=3):\n",
" \"\"\"Plot the top N most highly correlated pairs of leads and their correlations over time.\"\"\"\n",
" correlations = np.corrcoef(data)\n",
" np.fill_diagonal(correlations, 0) # Ignore self-correlations\n",
" top_pairs = np.unravel_index(np.argsort(correlations.ravel())[-top_n:], correlations.shape)\n",
" top_pairs = list(zip(top_pairs[0], top_pairs[1]))\n",
"\n",
" figs = []\n",
" for i, (lead1, lead2) in enumerate(top_pairs):\n",
" fig = go.Figure()\n",
" fig.add_trace(go.Scatter(\n",
" x=np.arange(len(data[lead1])),\n",
" y=data[lead1],\n",
" mode='lines',\n",
" name=f'Lead {lead1+1}'\n",
" ))\n",
" fig.add_trace(go.Scatter(\n",
" x=np.arange(len(data[lead2])),\n",
" y=data[lead2],\n",
" mode='lines',\n",
" name=f'Lead {lead2+1}',\n",
" line=dict(dash='dash')\n",
" ))\n",
" fig.update_layout(\n",
" title=f'Top Correlated Pair {i+1}: Lead {lead1+1} and Lead {lead2+1}',\n",
" xaxis_title='Sample',\n",
" yaxis_title='Amplitude'\n",
" )\n",
" figs.append(fig)\n",
"\n",
" # Create correlation matrix\n",
" lead1_data = data[lead1]\n",
" lead2_data = data[lead2]\n",
" arr = np.stack((lead1_data, lead2_data), axis=1)\n",
" correlation_matrix = np.cov(arr) / np.max(lead1_data) / np.max(lead2_data)\n",
"\n",
" # Replace NaN values with 0\n",
" #correlation_matrix = np.nan_to_num(correlation_matrix)\n",
"\n",
" fig = px.imshow(correlation_matrix,\n",
" color_continuous_scale='RdBu_r')\n",
"\n",
" fig.update_layout(\n",
" title=f'Correlation Over Time for Lead {lead1+1} and Lead {lead2+1}',\n",
" )\n",
" figs.append(fig)\n",
" return figs\n",
"\n",
"# Load and trim the data\n",
"trimmed_data = load_and_trim_data(DATA_DIR, SAMPLE_RATE, LEAD_LIMIT, TIME_LIMIT_SEC)\n",
"\n",
"# Plot individual leads\n",
"fig_individual_leads = plot_individual_leads(trimmed_data, SAMPLE_RATE)\n",
"fig_individual_leads.show()\n",
"\n",
"# Plot lead correlations\n",
"fig_lead_correlations = plot_lead_correlations(trimmed_data)\n",
"fig_lead_correlations.show()\n",
"\n",
"# Plot highly correlated leads\n",
"figs_highly_correlated_leads = plot_highly_correlated_leads(trimmed_data)\n",
"for fig in figs_highly_correlated_leads:\n",
" fig.show()\n",
"\n",
"# Plot top correlated pairs and their correlations over time\n",
"#figs_top_correlated_pairs = plot_top_correlated_pairs(trimmed_data)\n",
"#for fig in figs_top_correlated_pairs:\n",
"# fig.show()\n"
]
},
{
"cell_type": "code",
2024-05-25 17:33:19 +02:00
"execution_count": null,
"id": "8257585d",
2024-05-25 17:31:08 +02:00
"metadata": {
"scrolled": false
},
2024-05-25 17:33:19 +02:00
"outputs": [],
2024-05-25 17:31:08 +02:00
"source": [
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from scipy.io import wavfile\n",
"from data_processing import load_all_wavs\n",
"import random\n",
"\n",
"# Constants\n",
"SAMPLE_RATE = 19531\n",
"LEAD_LIMIT = 128\n",
"TIME_LIMIT_SEC = 0.1\n",
"DATA_DIR = \"data\"\n",
"RECONSTRUCTION_START_FRACTION = 0.5 # Fraction of time at which to start reconstruction\n",
"\n",
"# Utility functions for visualizations\n",
"\n",
"def load_and_trim_data(data_dir, sample_rate, lead_limit, time_limit_sec):\n",
" \"\"\"Load and trim the data to specified lead and time limits.\"\"\"\n",
" all_data = load_all_wavs(data_dir)\n",
" trimmed_data = []\n",
" max_samples = int(sample_rate * time_limit_sec)\n",
" for lead in all_data[:lead_limit]:\n",
" trimmed_data.append(lead[:max_samples])\n",
" return np.array(trimmed_data)\n",
"\n",
"def plot_random_lead_with_reconstruction(data, reconstruction_start_fraction, top_n=8):\n",
" \"\"\"Plot a random lead with reconstruction based on correlations with other leads.\"\"\"\n",
" num_leads = data.shape[0]\n",
" random_lead_index = random.randint(0, num_leads - 1)\n",
" \n",
" # Get the random lead data\n",
" random_lead_data = data[random_lead_index]\n",
" \n",
" # Calculate correlations\n",
" correlations = np.corrcoef(data)\n",
" np.fill_diagonal(correlations, 0) # Ignore self-correlations\n",
" \n",
" # Find the top N most correlated leads\n",
" most_correlated_indices = np.argsort(correlations[random_lead_index])[-top_n:]\n",
" \n",
" # Plot the original random lead\n",
" fig = go.Figure()\n",
" time = np.linspace(0, len(random_lead_data) / SAMPLE_RATE, num=len(random_lead_data))\n",
" fig.add_trace(go.Scatter(\n",
" x=time,\n",
" y=random_lead_data,\n",
" mode='lines',\n",
" name=f'Original Lead {random_lead_index+1}'\n",
" ))\n",
" \n",
" # Perform reconstruction\n",
" reconstruction_start_index = int(len(random_lead_data) * reconstruction_start_fraction)\n",
" reconstructed_data = np.zeros_like(random_lead_data)\n",
" reconstructed_data[reconstruction_start_index] = random_lead_data[reconstruction_start_index-2]*0.25 + random_lead_data[reconstruction_start_index-1]*0.25 + random_lead_data[reconstruction_start_index]*0.5\n",
" \n",
" for i in range(reconstruction_start_index + 1, len(random_lead_data)):\n",
" for correlated_index in most_correlated_indices:\n",
" reconstructed_data[i] += correlations[random_lead_index, correlated_index] * \\\n",
" (data[correlated_index, i] - data[correlated_index, i-1])\n",
" reconstructed_data[i] += reconstructed_data[i-1]\n",
" \n",
" # Plot the reconstructed lead\n",
" fig.add_trace(go.Scatter(\n",
" x=time[reconstruction_start_index:],\n",
" y=reconstructed_data[reconstruction_start_index:],\n",
" mode='lines',\n",
" name=f'Reconstructed Lead {random_lead_index+1}',\n",
" line=dict(dash='dash')\n",
" ))\n",
" \n",
" fig.update_layout(\n",
" title=f'Lead {random_lead_index+1} with Reconstructed Data',\n",
" xaxis_title='Time [s]',\n",
" yaxis_title='Amplitude',\n",
" legend_title_text='Data'\n",
" )\n",
" fig.show()\n",
"\n",
"# Load and trim the data\n",
"trimmed_data = load_and_trim_data(DATA_DIR, SAMPLE_RATE, LEAD_LIMIT, TIME_LIMIT_SEC)\n",
"\n",
"# Plot random lead with reconstruction\n",
"for i in range(8):\n",
" plot_random_lead_with_reconstruction(trimmed_data, RECONSTRUCTION_START_FRACTION)\n"
]
},
{
"cell_type": "code",
2024-05-25 17:33:19 +02:00
"execution_count": null,
"id": "b48fafda",
2024-05-25 17:31:08 +02:00
"metadata": {},
2024-05-25 17:33:19 +02:00
"outputs": [],
2024-05-25 17:31:08 +02:00
"source": [
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from scipy.io import wavfile\n",
"from data_processing import load_all_wavs\n",
"import random\n",
"from sklearn.gaussian_process import GaussianProcessRegressor\n",
"from sklearn.gaussian_process.kernels import RBF, Matern, RationalQuadratic\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Constants\n",
"SAMPLE_RATE = 19531\n",
"LEAD_LIMIT = 128\n",
"TIME_LIMIT_SEC = 0.1\n",
"DATA_DIR = \"data\"\n",
"RECONSTRUCTION_START_FRACTION = 0.5 # Fraction of time at which to start reconstruction\n",
"KERNEL_TYPE = 'RBF' # Configurable kernel: 'RBF', 'Matern', 'RationalQuadratic'\n",
"ALPHA = 1e-6 # Small noise term to ensure positive definite kernel\n",
"NUM_GP_SAMPLES = 1 # Number of GP samples to draw\n",
"TOP_N = 3 # Number of most correlated leads to consider\n",
"GP_TRAINING_POINTS = 64 # Number of recent points to use for GP fitting\n",
"\n",
"# Utility functions for visualizations\n",
"\n",
"def load_and_trim_data(data_dir, sample_rate, lead_limit, time_limit_sec):\n",
" \"\"\"Load and trim the data to specified lead and time limits.\"\"\"\n",
" all_data = load_all_wavs(data_dir)\n",
" trimmed_data = []\n",
" max_samples = int(sample_rate * time_limit_sec)\n",
" for lead in all_data[:lead_limit]:\n",
" trimmed_data.append(lead[:max_samples])\n",
" return np.array(trimmed_data)\n",
"\n",
"def plot_combined_reconstruction(data, reconstruction_start_fraction, kernel_type, num_gp_samples, top_n, gp_training_points):\n",
" \"\"\"Plot a random lead with combined GP extrapolation and correlated lead contributions.\"\"\"\n",
" num_leads = data.shape[0]\n",
" random_lead_index = random.randint(0, num_leads - 1)\n",
" \n",
" # Get the random lead data\n",
" random_lead_data = data[random_lead_index]\n",
" \n",
" # Normalize the data\n",
" scaler = StandardScaler()\n",
" random_lead_data_normalized = scaler.fit_transform(random_lead_data.reshape(-1, 1)).flatten()\n",
" \n",
" # Calculate correlations\n",
" correlations = np.corrcoef(data)\n",
" np.fill_diagonal(correlations, 0) # Ignore self-correlations\n",
" \n",
" # Find the top N most correlated leads\n",
" most_correlated_indices = np.argsort(correlations[random_lead_index])[-top_n:]\n",
" \n",
" # Plot the original random lead\n",
" fig = go.Figure()\n",
" time = np.linspace(0, len(random_lead_data) / SAMPLE_RATE, num=len(random_lead_data))\n",
" fig.add_trace(go.Scatter(\n",
" x=time,\n",
" y=random_lead_data,\n",
" mode='lines',\n",
" name=f'Original Lead {random_lead_index+1}'\n",
" ))\n",
"\n",
" # Perform GP extrapolation\n",
" reconstruction_start_index = int(len(random_lead_data) * reconstruction_start_fraction)\n",
" \n",
" # Fit Gaussian Process\n",
" kernel = None\n",
" if kernel_type == 'RBF':\n",
" kernel = RBF(length_scale=1.0, length_scale_bounds=(1e-5, 1e5))\n",
" elif kernel_type == 'Matern':\n",
" kernel = Matern(length_scale=1.0, length_scale_bounds=(1e-5, 1e5))\n",
" elif kernel_type == 'RationalQuadratic':\n",
" kernel = RationalQuadratic(length_scale=1.0, length_scale_bounds=(1e-5, 1e5))\n",
" \n",
" gp = GaussianProcessRegressor(kernel=kernel, alpha=ALPHA)\n",
" train_time = time[reconstruction_start_index - gp_training_points:reconstruction_start_index].reshape(-1, 1)\n",
" train_data = random_lead_data_normalized[reconstruction_start_index - gp_training_points:reconstruction_start_index]\n",
" gp.fit(train_time, train_data)\n",
" \n",
" extrapolated_time = time[reconstruction_start_index:].reshape(-1, 1)\n",
" extrapolated_mean_normalized, extrapolated_std_normalized = gp.predict(extrapolated_time, return_std=True)\n",
"\n",
" # Inverse transform the normalized predictions\n",
" extrapolated_mean = scaler.inverse_transform(extrapolated_mean_normalized.reshape(-1, 1)).flatten()\n",
" extrapolated_std = extrapolated_std_normalized * scaler.scale_[0]\n",
"\n",
" # Initialize the combined reconstruction with GP mean\n",
" combined_reconstruction = np.copy(extrapolated_mean)*0 + random_lead_data[reconstruction_start_index-1] *0.33 + random_lead_data[reconstruction_start_index]*0.66\n",
"\n",
" # Calculate variances (inverse of weights)\n",
" gp_variance = extrapolated_std ** 2\n",
"\n",
" # Perform reconstruction\n",
" reconstruction_start_index = int(len(random_lead_data) * reconstruction_start_fraction)\n",
" reconstructed_data = np.zeros_like(random_lead_data)\n",
" reconstructed_data[reconstruction_start_index] = random_lead_data[reconstruction_start_index-2]*0.25 + random_lead_data[reconstruction_start_index-1]*0.25 + random_lead_data[reconstruction_start_index]*0.5\n",
" \n",
" for i in range(reconstruction_start_index + 1, len(random_lead_data)):\n",
" for correlated_index in most_correlated_indices:\n",
" reconstructed_data[i] += correlations[random_lead_index, correlated_index] * \\\n",
" (data[correlated_index, i] - data[correlated_index, i-1])\n",
" reconstructed_data[i] += reconstructed_data[i-1]\n",
"\n",
" #combined_reconstruction = reconstructed_data\n",
" combined_reconstruction = (extrapolated_mean / gp_variance + reconstructed_data[reconstruction_start_index:]) / (1 / gp_variance + 1)\n",
" \n",
" # Plot the GP extrapolated mean\n",
" #fig.add_trace(go.Scatter(\n",
" # x=time[reconstruction_start_index:],\n",
" # y=extrapolated_mean,\n",
" # mode='lines',\n",
" # name=f'GP Extrapolated Lead {random_lead_index+1}',\n",
" # line=dict(dash='dot')\n",
" #))\n",
"\n",
" # Plot the combined reconstruction\n",
" fig.add_trace(go.Scatter(\n",
" x=time[reconstruction_start_index:],\n",
" y=combined_reconstruction,\n",
" mode='lines',\n",
" name=f'Combined Reconstruction Lead {random_lead_index+1}',\n",
" line=dict(dash='dash')\n",
" ))\n",
"\n",
" fig.add_trace(go.Scatter(\n",
" x=time[reconstruction_start_index:],\n",
" y=reconstructed_data[reconstruction_start_index:],\n",
" mode='lines',\n",
" name=f'Peer Reconstruction Lead {random_lead_index+1}',\n",
" line=dict(dash='dot')\n",
" ))\n",
" \n",
" fig.update_layout(\n",
" title=f'Lead {random_lead_index+1} with Combined GP and Correlated Lead Contributions',\n",
" xaxis_title='Time [s]',\n",
" yaxis_title='Amplitude',\n",
" legend_title_text='Data'\n",
" )\n",
" fig.show()\n",
"\n",
"# Load and trim the data\n",
"trimmed_data = load_and_trim_data(DATA_DIR, SAMPLE_RATE, LEAD_LIMIT, TIME_LIMIT_SEC)\n",
"\n",
"# Plot combined reconstruction\n",
"plot_combined_reconstruction(trimmed_data, RECONSTRUCTION_START_FRACTION, KERNEL_TYPE, NUM_GP_SAMPLES, TOP_N, GP_TRAINING_POINTS)\n"
]
},
{
"cell_type": "code",
2024-05-25 17:33:19 +02:00
"execution_count": null,
"id": "0c8b6f77",
2024-05-25 17:31:08 +02:00
"metadata": {
"scrolled": true
},
2024-05-25 17:33:19 +02:00
"outputs": [],
2024-05-25 17:31:08 +02:00
"source": [
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from scipy.io import wavfile\n",
"from data_processing import load_all_wavs\n",
"import random\n",
"from sklearn.gaussian_process import GaussianProcessRegressor\n",
"from sklearn.gaussian_process.kernels import RBF, Matern, RationalQuadratic\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Constants\n",
"SAMPLE_RATE = 19531\n",
"LEAD_LIMIT = 128\n",
"TIME_LIMIT_SEC = 0.1\n",
"DATA_DIR = \"data\"\n",
"RECONSTRUCTION_START_FRACTION = 0.5 # Fraction of time at which to start reconstruction\n",
"KERNEL_TYPE = 'RBF' # Configurable kernel: 'RBF', 'Matern', 'RationalQuadratic'\n",
"ALPHA = 1e-6 # Small noise term to ensure positive definite kernel\n",
"NUM_GP_SAMPLES = 1 # Number of GP samples to draw\n",
"TOP_N = 3 # Number of most correlated leads to consider\n",
"GP_TRAINING_POINTS = 64 # Number of recent points to use for GP fitting\n",
"WINDOW_SIZE = 10 # Window size for autoregressive GP\n",
"\n",
"def load_and_trim_data(data_dir, sample_rate, lead_limit, time_limit_sec):\n",
" \"\"\"Load and trim the data to specified lead and time limits.\"\"\"\n",
" all_data = load_all_wavs(data_dir)\n",
" trimmed_data = []\n",
" max_samples = int(sample_rate * time_limit_sec)\n",
" for lead in all_data[:lead_limit]:\n",
" trimmed_data.append(lead[:max_samples])\n",
" return np.array(trimmed_data)\n",
"\n",
"def plot_combined_reconstruction(data, reconstruction_start_fraction, kernel_type, num_gp_samples, top_n, gp_training_points, window_size):\n",
" \"\"\"Plot a random lead with combined GP extrapolation and correlated lead contributions.\"\"\"\n",
" num_leads = data.shape[0]\n",
" random_lead_index = random.randint(0, num_leads - 1)\n",
" \n",
" # Get the random lead data\n",
" random_lead_data = data[random_lead_index]\n",
" \n",
" # Normalize the data\n",
" scaler = StandardScaler()\n",
" random_lead_data_normalized = scaler.fit_transform(random_lead_data.reshape(-1, 1)).flatten()\n",
" \n",
" # Calculate correlations\n",
" correlations = np.corrcoef(data)\n",
" np.fill_diagonal(correlations, 0) # Ignore self-correlations\n",
" \n",
" # Find the top N most correlated leads\n",
" most_correlated_indices = np.argsort(correlations[random_lead_index])[-top_n:]\n",
" \n",
" # Plot the original random lead\n",
" fig = go.Figure()\n",
" time = np.linspace(0, len(random_lead_data) / SAMPLE_RATE, num=len(random_lead_data))\n",
" fig.add_trace(go.Scatter(\n",
" x=time,\n",
" y=random_lead_data,\n",
" mode='lines',\n",
" name=f'Original Lead {random_lead_index+1}'\n",
" ))\n",
"\n",
" # Determine reconstruction start index\n",
" reconstruction_start_index = int(len(random_lead_data) * reconstruction_start_fraction)\n",
" \n",
" # Fit Gaussian Process on the initial training data\n",
" kernel = None\n",
" if kernel_type == 'RBF':\n",
" kernel = RBF(length_scale=1.0, length_scale_bounds=(1e-5, 1e5))\n",
" elif kernel_type == 'Matern':\n",
" kernel = Matern(length_scale=1.0, length_scale_bounds=(1e-5, 1e5))\n",
" elif kernel_type == 'RationalQuadratic':\n",
" kernel = RationalQuadratic(length_scale=1.0, length_scale_bounds=(1e-5, 1e5))\n",
" \n",
" gp = GaussianProcessRegressor(kernel=kernel, alpha=ALPHA)\n",
"\n",
" extrapolated_mean = np.zeros(len(random_lead_data) - reconstruction_start_index)\n",
" extrapolated_std = np.zeros(len(random_lead_data) - reconstruction_start_index)\n",
" combined_reconstruction = np.zeros(len(random_lead_data)) + random_lead_data[reconstruction_start_index]\n",
" reconstructed_data = np.zeros(len(random_lead_data))\n",
" \n",
" for i in range(reconstruction_start_index, len(random_lead_data)):\n",
" # Use past samples to ensure the window is full\n",
" start_index = max(0, i - window_size)\n",
" \n",
" # Prepare training data for GP\n",
" train_time = time[start_index:i].reshape(-1, 1)\n",
" train_data = combined_reconstruction[start_index:i]\n",
" train_data_normalized = scaler.transform(train_data.reshape(-1, 1)).flatten()\n",
" \n",
" # Fit the GP on the current window\n",
" gp.fit(train_time, train_data_normalized)\n",
" \n",
" # Predict the next value\n",
" next_time = np.array([[time[i]]])\n",
" mean, std = gp.predict(next_time, return_std=True)\n",
" \n",
" # Inverse transform the normalized predictions\n",
" mean = scaler.inverse_transform(mean.reshape(-1, 1)).flatten()[0]\n",
" std = std[0] * scaler.scale_[0]\n",
" \n",
" # Calculate the variance\n",
" gp_variance = std ** 2\n",
" \n",
" # Calculate contributions from the most correlated leads\n",
" correlated_contribution = combined_reconstruction[i-1]\n",
" for correlated_index in most_correlated_indices:\n",
" correlated_contribution += correlations[random_lead_index, correlated_index] * (data[correlated_index, i] - data[correlated_index, i-1])\n",
" \n",
" # Combine GP prediction with peer-based reconstruction\n",
" combined_value = (mean / gp_variance + correlated_contribution) / (1 / gp_variance + 1)\n",
" combined_reconstruction[i] = combined_value\n",
" \n",
" # Update the reconstructed data\n",
" reconstructed_data[i] = correlated_contribution\n",
" \n",
" # Store GP predictions\n",
" if i - reconstruction_start_index < len(extrapolated_mean):\n",
" extrapolated_mean[i - reconstruction_start_index] = mean\n",
" extrapolated_std[i - reconstruction_start_index] = std\n",
"\n",
" # Plot the combined reconstruction\n",
" fig.add_trace(go.Scatter(\n",
" x=time[reconstruction_start_index:],\n",
" y=combined_reconstruction[reconstruction_start_index:],\n",
" mode='lines',\n",
" name=f'Combined Reconstruction Lead {random_lead_index+1}',\n",
" line=dict(dash='dash')\n",
" ))\n",
"\n",
" fig.add_trace(go.Scatter(\n",
" x=time[reconstruction_start_index:],\n",
" y=reconstructed_data[reconstruction_start_index:],\n",
" mode='lines',\n",
" name=f'Peer Reconstruction Lead {random_lead_index+1}',\n",
" line=dict(dash='dot')\n",
" ))\n",
" \n",
" fig.update_layout(\n",
" title=f'Lead {random_lead_index+1} with Combined GP and Correlated Lead Contributions',\n",
" xaxis_title='Time [s]',\n",
" yaxis_title='Amplitude',\n",
" legend_title_text='Data'\n",
" )\n",
" fig.show()\n",
"\n",
"# Load and trim the data\n",
"trimmed_data = load_and_trim_data(DATA_DIR, SAMPLE_RATE, LEAD_LIMIT, TIME_LIMIT_SEC)\n",
"\n",
"# Plot combined reconstruction\n",
"for _ in range(8):\n",
" plot_combined_reconstruction(trimmed_data, RECONSTRUCTION_START_FRACTION, KERNEL_TYPE, NUM_GP_SAMPLES, TOP_N, GP_TRAINING_POINTS, WINDOW_SIZE)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
2024-05-25 17:33:19 +02:00
"id": "4642570b",
2024-05-25 17:31:08 +02:00
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}