Fixed data path

This commit is contained in:
Dominik Moritz Roth 2024-05-25 17:44:12 +02:00
parent 1a44b0efad
commit bfbf9019d5
2 changed files with 6 additions and 8 deletions

View File

@ -4,13 +4,12 @@ import urllib.request
import zipfile
import os
def download_and_extract_data(url, data_dir):
if not os.path.exists(data_dir):
os.makedirs(data_dir)
zip_path = os.path.join(data_dir, 'data.zip')
def download_and_extract_data(url):
if not os.path.exists('data'):
zip_path = os.path.join('.', 'data.zip')
urllib.request.urlretrieve(url, zip_path)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(data_dir)
zip_ref.extractall('.')
os.remove(zip_path)
def load_wav(file_path):

View File

@ -22,10 +22,9 @@ class SpikeRunner:
data_config = slate.consume(config, 'data', expand=True)
data_url = slate.consume(data_config, 'url')
data_dir = slate.consume(data_config, 'directory')
cut_length = slate.consume(data_config, 'cut_length', None)
download_and_extract_data(data_url, data_dir)
all_data = load_all_wavs(data_dir, cut_length)
download_and_extract_data(data_url)
all_data = load_all_wavs('data', cut_length)
split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)