Spikey/cli.py

51 lines
2.0 KiB
Python
Raw Normal View History

2024-05-24 22:01:59 +02:00
import argparse
import yaml
import os
import torch
from data_processing import download_and_extract_data, load_all_wavs, save_wav, delta_encode, delta_decode
from main import SpikeRunner
def load_config(config_path):
    """Load and return the YAML configuration stored at *config_path*.

    The file is parsed with ``yaml.safe_load``, so no arbitrary Python objects
    are constructed from it.  An empty file yields ``{}`` instead of ``None``.
    Callers that index the result then get a ``KeyError`` naming the missing
    key, rather than a ``TypeError`` on ``None``.
    """
    # Explicit UTF-8: the default text encoding is platform/locale dependent.
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    # safe_load returns None for an empty document.
    return {} if config is None else config
# Size in bytes of the big-endian length prefix that frames each compressed
# segment in the output container written by ``compress``.
_SEGMENT_LENGTH_BYTES = 4


def _write_segments(path, segments):
    """Write each byte segment to *path*, preceded by its length.

    Framing is required so that ``decompress`` can split the file back into
    the individual segments; plain concatenation loses the boundaries.
    """
    with open(path, 'wb') as f:
        for segment in segments:
            payload = bytes(segment)
            f.write(len(payload).to_bytes(_SEGMENT_LENGTH_BYTES, 'big'))
            f.write(payload)


def _read_segments(path):
    """Return the list of byte segments stored by ``_write_segments`` at *path*.

    Raises:
        ValueError: if the file ends inside a length prefix or a payload.
    """
    segments = []
    with open(path, 'rb') as f:
        while True:
            header = f.read(_SEGMENT_LENGTH_BYTES)
            if not header:
                break  # clean end of file
            if len(header) != _SEGMENT_LENGTH_BYTES:
                raise ValueError(f"{path}: truncated segment header")
            size = int.from_bytes(header, 'big')
            payload = f.read(size)
            if len(payload) != size:
                raise ValueError(f"{path}: truncated segment payload")
            segments.append(payload)
    return segments


def _build_parser():
    """Return the argument parser for the compression CLI."""
    parser = argparse.ArgumentParser(description="WAV Compression with Neural Networks")
    parser.add_argument('action', choices=['compress', 'decompress'], help="Action to perform")
    parser.add_argument('--config', default='config.yaml', help="Path to the configuration file")
    # Both paths are used by every action; requiring them yields a usage error
    # instead of a TypeError from open(None).
    parser.add_argument('--input_file', required=True,
                        help="Path to the input file (WAV for compress, compressed data for decompress)")
    parser.add_argument('--output_file', required=True, help="Path to the output file")
    parser.add_argument('--sample_rate', type=int, default=19531,
                        help="Sample rate in Hz of the WAV written by 'decompress'")
    return parser


def _compress(spike_runner, config, input_file, output_file):
    """Run the model and entropy encoder over *input_file*; write *output_file*."""
    data = load_all_wavs(input_file)
    if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
        data = [delta_encode(d) for d in data]
    spike_runner.encoder.build_model(data)
    # Inference only: no_grad avoids building an autograd graph per segment.
    with torch.no_grad():
        encoded_data = [
            spike_runner.model(torch.tensor(d, dtype=torch.float32).unsqueeze(0))
            .squeeze(0).detach().cpu().numpy().tolist()
            for d in data
        ]
    compressed_data = [spike_runner.encoder.encode(ed) for ed in encoded_data]
    _write_segments(output_file, compressed_data)


def _decompress(spike_runner, config, input_file, output_file, sample_rate):
    """Decode the segments in *input_file* and write them as a WAV to *output_file*."""
    compressed_data = _read_segments(input_file)
    # NOTE(review): each segment is passed as a list of byte values (what the
    # original `list(f.read())` produced), and decode's second argument is the
    # segment's byte length, as the original call intended. Confirm against
    # encoder.decode.
    # NOTE(review): encoder.build_model is not called here and the model step
    # applied in compress is not inverted. Confirm this is intended.
    decoded_data = [spike_runner.encoder.decode(list(cd), len(cd)) for cd in compressed_data]
    if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
        decoded_data = [delta_decode(dd) for dd in decoded_data]
    save_wav(output_file, sample_rate, decoded_data)


def main(argv=None):
    """CLI entry point: parse arguments and run ``compress`` or ``decompress``.

    Args:
        argv: Argument list to parse; ``None`` (the default) uses ``sys.argv``.
    """
    args = _build_parser().parse_args(argv)
    config = load_config(args.config)
    spike_runner = SpikeRunner(None, config)
    spike_runner.setup('CLI')
    if args.action == 'compress':
        _compress(spike_runner, config, args.input_file, args.output_file)
    elif args.action == 'decompress':
        _decompress(spike_runner, config, args.input_file, args.output_file, args.sample_rate)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())