Spikey/data_processing.py

47 lines
1.4 KiB
Python
Raw Normal View History

2024-05-24 22:01:59 +02:00
import numpy as np
from scipy.io import wavfile
import urllib.request
import zipfile
2024-05-24 23:02:24 +02:00
import os
2024-05-24 22:01:59 +02:00
def download_and_extract_data(url, data_dir):
    """Download a zip archive from ``url`` and unpack it into ``data_dir``.

    The archive is stored temporarily as ``data.zip`` inside ``data_dir``,
    extracted in place, and then deleted.

    Args:
        url: Location of the zip archive; handed to
            ``urllib.request.urlretrieve``.
        data_dir: Directory that receives the extracted files. It is created
            (including missing parents) when it does not exist yet.

    Raises:
        Any exception raised by the download or by ``zipfile`` (for example
        ``zipfile.BadZipFile`` when the payload is not a zip archive) is
        propagated after the temporary archive has been removed.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(data_dir, exist_ok=True)
    zip_path = os.path.join(data_dir, 'data.zip')
    try:
        urllib.request.urlretrieve(url, zip_path)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
    finally:
        # Never leave a partial or corrupt data.zip behind, even on failure.
        if os.path.exists(zip_path):
            os.remove(zip_path)
def load_wav(file_path):
    """Read a WAV file and return ``(sample_rate, data)``.

    Thin wrapper around ``scipy.io.wavfile.read``: the first element is the
    sample rate reported by the file, the second the sample array.
    """
    return wavfile.read(file_path)
2024-05-25 17:31:08 +02:00
def load_all_wavs(data_dir, cut_length=None):
    """Load every ``.wav`` file found directly inside ``data_dir``.

    Files are loaded in sorted file-name order so that the position of each
    signal in the returned list is reproducible; ``os.listdir`` alone makes
    no ordering guarantee.

    Args:
        data_dir: Directory to scan (not recursive). Only entries whose name
            ends with ``.wav`` are loaded.
        cut_length: When truthy, each signal is truncated with
            ``data[:cut_length]``. ``None`` (the default) or ``0`` keeps the
            full signal.

    Returns:
        A list with one sample array per file. The sample rates returned by
        ``load_wav`` are discarded.
    """
    wav_files = [
        os.path.join(data_dir, name)
        for name in sorted(os.listdir(data_dir))
        if name.endswith('.wav')
    ]
    all_data = []
    for file_path in wav_files:
        _, data = load_wav(file_path)
        if cut_length:
            data = data[:cut_length]
        all_data.append(data)
    return all_data
2024-05-25 17:31:08 +02:00
def compute_correlation_matrix(data):
    """Return the pairwise correlation matrix of the signals in ``data``.

    Entry ``[i, j]`` (``i != j``) is ``np.corrcoef(data[i], data[j])[0, 1]``.
    The diagonal is deliberately left at ``0`` rather than ``1``, so a
    signal's correlation with itself never shows up in the matrix.

    Args:
        data: Sequence of signals; ``len(data)`` determines the matrix size.
            Each pair is passed to ``np.corrcoef``, so paired signals must
            have compatible lengths.

    Returns:
        A ``(len(data), len(data))`` float array. It is empty for empty
        input and ``[[0.0]]`` for a single signal.
    """
    num_leads = len(data)
    corr_matrix = np.zeros((num_leads, num_leads))
    for i in range(num_leads):
        # Correlation is symmetric: compute each unordered pair once and
        # mirror it instead of evaluating both (i, j) and (j, i).
        for j in range(i + 1, num_leads):
            coefficient = np.corrcoef(data[i], data[j])[0, 1]
            corr_matrix[i, j] = coefficient
            corr_matrix[j, i] = coefficient
    return corr_matrix
2024-05-24 22:01:59 +02:00
2024-05-25 17:31:08 +02:00
def split_data_by_time(data, split_ratio=0.5):
    """Split every lead into a leading (train) and trailing (test) part.

    For each lead the cut index is ``int(len(lead) * split_ratio)``; samples
    before the index go to the train part, the rest to the test part.

    Args:
        data: Iterable of sliceable sequences (one per lead).
        split_ratio: Fraction of each lead assigned to the train part.

    Returns:
        A ``(train_data, test_data)`` tuple of lists, each holding one slice
        per lead in the input order.
    """
    def cut(lead):
        boundary = int(len(lead) * split_ratio)
        return lead[:boundary], lead[boundary:]

    # Single pass over ``data`` so one-shot iterables are handled as well.
    halves = [cut(lead) for lead in data]
    head_parts = [head for head, _ in halves]
    tail_parts = [tail for _, tail in halves]
    return head_parts, tail_parts