Fixed data path
This commit is contained in:
parent
1a44b0efad
commit
bfbf9019d5
@ -4,13 +4,12 @@ import urllib.request
|
||||
import zipfile
|
||||
import os
|
||||
|
||||
def download_and_extract_data(url, data_dir):
|
||||
if not os.path.exists(data_dir):
|
||||
os.makedirs(data_dir)
|
||||
zip_path = os.path.join(data_dir, 'data.zip')
|
||||
def download_and_extract_data(url):
|
||||
if not os.path.exists('data'):
|
||||
zip_path = os.path.join('.', 'data.zip')
|
||||
urllib.request.urlretrieve(url, zip_path)
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(data_dir)
|
||||
zip_ref.extractall('.')
|
||||
os.remove(zip_path)
|
||||
|
||||
def load_wav(file_path):
|
||||
|
5
main.py
5
main.py
@ -22,10 +22,9 @@ class SpikeRunner:
|
||||
data_config = slate.consume(config, 'data', expand=True)
|
||||
|
||||
data_url = slate.consume(data_config, 'url')
|
||||
data_dir = slate.consume(data_config, 'directory')
|
||||
cut_length = slate.consume(data_config, 'cut_length', None)
|
||||
download_and_extract_data(data_url, data_dir)
|
||||
all_data = load_all_wavs(data_dir, cut_length)
|
||||
download_and_extract_data(data_url)
|
||||
all_data = load_all_wavs('data', cut_length)
|
||||
|
||||
split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
|
||||
self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
|
||||
|
Loading…
Reference in New Issue
Block a user