diff --git a/data_processing.py b/data_processing.py index 3ba35be..87163e7 100644 --- a/data_processing.py +++ b/data_processing.py @@ -4,13 +4,12 @@ import urllib.request import zipfile import os -def download_and_extract_data(url, data_dir): - if not os.path.exists(data_dir): - os.makedirs(data_dir) - zip_path = os.path.join(data_dir, 'data.zip') +def download_and_extract_data(url): + if not os.path.exists('data'): + zip_path = os.path.join('.', 'data.zip') urllib.request.urlretrieve(url, zip_path) with zipfile.ZipFile(zip_path, 'r') as zip_ref: - zip_ref.extractall(data_dir) + zip_ref.extractall('.') os.remove(zip_path) def load_wav(file_path): diff --git a/main.py b/main.py index 197e0d8..21ea06f 100644 --- a/main.py +++ b/main.py @@ -22,10 +22,9 @@ class SpikeRunner: data_config = slate.consume(config, 'data', expand=True) data_url = slate.consume(data_config, 'url') - data_dir = slate.consume(data_config, 'directory') cut_length = slate.consume(data_config, 'cut_length', None) - download_and_extract_data(data_url, data_dir) - all_data = load_all_wavs(data_dir, cut_length) + download_and_extract_data(data_url) + all_data = load_all_wavs('data', cut_length) split_ratio = slate.consume(data_config, 'split_ratio', 0.5) self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)