# File-viewer metadata: 51 lines, 2.0 KiB, Python
import argparse
|
|
import yaml
|
|
import os
|
|
import torch
|
|
from data_processing import download_and_extract_data, load_all_wavs, save_wav, delta_encode, delta_decode
|
|
from main import SpikeRunner
|
|
|
|
def load_config(config_path):
    """Load the YAML configuration file at *config_path*.

    Args:
        config_path: Path to a YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents
        (``None`` for an empty document).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the contents are not valid YAML.
    """
    # Pin the encoding: without it, open() falls back to the platform locale,
    # which can mis-decode a UTF-8 config on non-UTF-8 systems.
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
|
|
|
|
def _compress(spike_runner, config, input_file, output_file):
    """Run the compress action: load WAV data, encode it, write the bytes."""
    data = load_all_wavs(input_file)
    if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
        data = [delta_encode(d) for d in data]

    spike_runner.encoder.build_model(data)

    encoded_data = []
    for d in data:
        # Add a leading dimension for the model call, then drop it again.
        batch = torch.tensor(d, dtype=torch.float32).unsqueeze(0)
        output = spike_runner.model(batch).squeeze(0)
        encoded_data.append(output.detach().numpy().tolist())

    compressed_data = [spike_runner.encoder.encode(ed) for ed in encoded_data]

    # NOTE(review): the streams are written back-to-back with no length
    # framing, so per-input boundaries are not recoverable from the file.
    with open(output_file, 'wb') as f:
        for cd in compressed_data:
            f.write(bytearray(cd))


def _decompress(spike_runner, config, input_file, output_file, sample_rate):
    """Run the decompress action: read the bytes, decode them, save a WAV."""
    with open(input_file, 'rb') as f:
        payload = list(f.read())

    # The file is one unframed byte stream, so decode it as a whole.
    # Iterating over it would hand single ints to decode(), and len() of an
    # int raises TypeError.
    # NOTE(review): assumes decode(data, length) takes the full sequence of
    # byte values and its length -- confirm against the encoder.
    if payload:
        decoded_data = [spike_runner.encoder.decode(payload, len(payload))]
    else:
        decoded_data = []

    if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
        decoded_data = [delta_decode(dd) for dd in decoded_data]

    save_wav(output_file, sample_rate, decoded_data)


def main():
    """Command-line entry point for compressing or decompressing WAV data.

    Parses the command line, loads the configuration, sets up a
    ``SpikeRunner`` and dispatches to the requested action.

    Command-line arguments:
        action: ``compress`` or ``decompress``.
        --config: Path to the configuration file (default ``config.yaml``).
        --input_file: Input path; must be given for ``decompress``.
        --output_file: Output path; must be given for ``compress``.
        --sample_rate: Sample rate in Hz passed to ``save_wav`` on
            ``decompress`` (default 19531).

    Exits with status 2 via ``parser.error`` when a path that is opened
    directly by the chosen action is missing.
    """
    parser = argparse.ArgumentParser(description="WAV Compression with Neural Networks")
    parser.add_argument('action', choices=['compress', 'decompress'], help="Action to perform")
    parser.add_argument('--config', default='config.yaml', help="Path to the configuration file")
    parser.add_argument('--input_file', help="Path to the input WAV file")
    parser.add_argument('--output_file', help="Path to the output file")
    parser.add_argument('--sample_rate', type=int, default=19531,
                        help="Sample rate in Hz used when writing the decompressed WAV")
    args = parser.parse_args()

    # Fail fast on paths that are handed straight to open(); otherwise the
    # error is a TypeError on None, raised only after setup and encoding.
    if args.action == 'compress' and args.output_file is None:
        parser.error("--output_file is required for 'compress'")
    if args.action == 'decompress' and args.input_file is None:
        parser.error("--input_file is required for 'decompress'")

    config = load_config(args.config)

    spike_runner = SpikeRunner(None, config)
    spike_runner.setup('CLI')

    if args.action == 'compress':
        _compress(spike_runner, config, args.input_file, args.output_file)
    elif args.action == 'decompress':
        _decompress(spike_runner, config, args.input_file, args.output_file,
                    args.sample_rate)
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|