2nd commit
parent 0c5f888d75
commit e0f51b5ee0
README.md (34 changed lines)

@@ -18,40 +18,6 @@ pip install -r requirements.txt
 ## Usage
 
-### Configuration
-
-The configuration for training and evaluation is specified in a YAML file. Below is an example configuration:
-
-```yaml
-name: Test
-
-preprocessing:
-  use_delta_encoding: true  # Whether to use delta encoding.
-
-predictor:
-  type: lstm  # Options: 'lstm', 'fixed_input_nn'
-  input_size: 1  # Input size for the LSTM predictor.
-  hidden_size: 128  # Hidden size for the LSTM or Fixed Input NN predictor.
-  num_layers: 2  # Number of layers for the LSTM predictor.
-  fixed_input_size: 10  # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'.
-
-training:
-  epochs: 10  # Number of training epochs.
-  batch_size: 32  # Batch size for training.
-  learning_rate: 0.001  # Learning rate for the optimizer.
-  eval_freq: 2  # Frequency of evaluation during training (in epochs).
-  save_path: models  # Directory to save the best model and encoder.
-  num_points: 1000  # Number of data points to visualize.
-
-bitstream_encoding:
-  type: arithmetic  # Use arithmetic encoding.
-
-data:
-  url: https://content.neuralink.com/compression-challenge/data.zip  # URL to download the dataset.
-  directory: data  # Directory to extract and store the dataset.
-  split_ratio: 0.8  # Ratio to split the data into train and test sets.
-```
-
 ### Running the Code
 
 To train the model and compress/decompress WAV files, use the CLI provided:
bitstream.py (23 changed lines)

@@ -1,3 +1,4 @@
+import bz2
 from abc import ABC, abstractmethod
 from arithmetic_compressor import AECompressor
 from arithmetic_compressor.models import StaticModel
@@ -15,6 +16,16 @@ class BaseEncoder(ABC):
     def build_model(self, data):
         pass
 
+
+class IdentityEncoder(BaseEncoder):
+    def encode(self, data):
+        return data
+
+    def decode(self, encoded_data, num_symbols):
+        return encoded_data
+
+    def build_model(self, data):
+        pass
+
 class ArithmeticEncoder(BaseEncoder):
     def encode(self, data):
         if not hasattr(self, 'model'):
@@ -29,7 +40,19 @@ class ArithmeticEncoder(BaseEncoder):
         return decoded_data
 
     def build_model(self, data):
+        # Convert data to list of tuples
+        data = [tuple(d) for d in data]
         symbol_counts = {symbol: data.count(symbol) for symbol in set(data)}
         total_symbols = sum(symbol_counts.values())
         probabilities = {symbol: count / total_symbols for symbol, count in symbol_counts.items()}
         self.model = StaticModel(probabilities)
+
+
+class Bzip2Encoder(BaseEncoder):
+    def encode(self, data):
+        return bz2.compress(bytearray(data))
+
+    def decode(self, encoded_data, num_symbols):
+        return list(bz2.decompress(encoded_data))
+
+    def build_model(self, data):
+        pass
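For reference, a minimal round-trip with the two encoders this commit adds (not part of the diff; the symbol values below are made up):

```python
# Illustrative only: assumes bitstream.py is importable; values are arbitrary.
from bitstream import IdentityEncoder, Bzip2Encoder

data = [3, 1, 4, 1, 5, 9, 2, 6]

for encoder in (IdentityEncoder(), Bzip2Encoder()):
    encoder.build_model(data)  # no-op for these two encoders
    compressed = encoder.encode(data)
    restored = encoder.decode(compressed, num_symbols=len(data))
    assert list(restored) == data
```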
config.yaml (18 changed lines)

@@ -30,38 +30,42 @@ wandb:
   group: '{config[name]}'
   job_type: '{delta_desc}'
   name: '{job_id}_{task_id}:{run_id}:{rand}={config[name]}_{delta_desc}'
-  tags:
-  - '{config[env][name]}'
-  - '{config[algo][name]}'
+  #tags:
+  # - '{config[env][name]}'
+  # - '{config[algo][name]}'
   sync_tensorboard: False
   monitor_gym: False
   save_code: False
 
 ---
 name: Test
+import: $
 
 preprocessing:
-  use_delta_encoding: true  # Whether to use delta encoding.
+  use_delta_encoding: false  # Whether to use delta encoding.
 
 predictor:
   type: lstm  # Options: 'lstm', 'fixed_input_nn'
   input_size: 1  # Input size for the LSTM predictor.
-  hidden_size: 128  # Hidden size for the LSTM or Fixed Input NN predictor.
+  hidden_size: 16  # Hidden size for the LSTM or Fixed Input NN predictor.
   num_layers: 2  # Number of layers for the LSTM predictor.
   fixed_input_size: 10  # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'.
 
 training:
   epochs: 10  # Number of training epochs.
-  batch_size: 32  # Batch size for training.
+  batch_size: 8  # Batch size for training.
   learning_rate: 0.001  # Learning rate for the optimizer.
   eval_freq: 2  # Frequency of evaluation during training (in epochs).
   save_path: models  # Directory to save the best model and encoder.
   num_points: 1000  # Number of data points to visualize
 
 bitstream_encoding:
-  type: arithmetic  # Use arithmetic encoding.
+  type: identity  # Options: 'arithmetic', 'no_compression', 'bzip2'
 
 data:
   url: https://content.neuralink.com/compression-challenge/data.zip  # URL to download the dataset.
   directory: data  # Directory to extract and store the dataset.
   split_ratio: 0.8  # Ratio to split the data into train and test sets.
+
+profiler:
+  enable: false
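Note that config.yaml is a multi-document YAML file: the block above the `---` holds the wandb defaults, and the block below it holds the `Test` experiment. A sketch of inspecting it with plain PyYAML (the project itself loads it through Slate; this is only to show the file structure):

```python
# Hedged sketch, assuming the rest of config.yaml is plain YAML as well.
import yaml

with open('config.yaml') as f:
    documents = list(yaml.safe_load_all(f))

experiment = documents[-1]
print(experiment['name'])                # e.g. Test
print(experiment['bitstream_encoding'])  # e.g. {'type': 'identity'}
```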
data_processing.py

@@ -1,8 +1,8 @@
-import os
 import numpy as np
 from scipy.io import wavfile
 import urllib.request
 import zipfile
+import os
 
 def download_and_extract_data(url, data_dir):
     if not os.path.exists(data_dir):
@@ -35,7 +35,8 @@ def delta_encode(data):
     """Apply delta encoding to the data."""
     deltas = [data[0]]
     for i in range(1, len(data)):
-        deltas.append(data[i] - data[i - 1])
+        delta = np.subtract(data[i], data[i - 1], dtype=np.float32)  # Using numpy subtract to handle overflow
+        deltas.append(delta)
     return deltas
 
 def delta_decode(deltas):
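The switch to `np.subtract(..., dtype=np.float32)` matters because consecutive int16 WAV samples can differ by more than int16 can represent; a minimal illustration (not from the diff, sample values invented):

```python
# Two in-range int16 samples whose difference exceeds the int16 range.
import numpy as np

a, b = np.int16(32000), np.int16(-32000)
print(a - b)                                # int16 arithmetic can wrap around here
print(np.subtract(a, b, dtype=np.float32))  # 64000.0, the true delta
```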
main.py (55 changed lines)

@@ -1,9 +1,12 @@
-import yaml
 from slate import Slate, Slate_Runner
 
+from pycallgraph2 import PyCallGraph
+from pycallgraph2.output import GraphvizOutput
+
 from data_processing import download_and_extract_data, load_all_wavs, delta_encode
 from model import LSTMPredictor, FixedInputNNPredictor
 from train import train_model
-from bitstream import ArithmeticEncoder
+from bitstream import ArithmeticEncoder, IdentityEncoder, Bzip2Encoder
 
 class SpikeRunner(Slate_Runner):
     def setup(self, name):
@@ -23,7 +26,14 @@ class SpikeRunner(Slate_Runner):
         download_and_extract_data(data_url, data_dir)
         all_data = load_all_wavs(data_dir)
 
-        if slate.consume(preprocessing_config, 'use_delta_encoding'):
+        self.epochs = slate.consume(training_config, 'epochs')
+        self.batch_size = slate.consume(training_config, 'batch_size')
+        self.learning_rate = slate.consume(training_config, 'learning_rate')
+        self.use_delta_encoding = slate.consume(preprocessing_config, 'use_delta_encoding')
+        self.eval_freq = slate.consume(training_config, 'eval_freq')
+        self.save_path = slate.consume(training_config, 'save_path', 'models')
+
+        if self.use_delta_encoding:
             all_data = [delta_encode(d) for d in all_data]
 
         # Split data into train and test sets
@@ -35,34 +45,47 @@ class SpikeRunner(Slate_Runner):
         # Model setup
         self.model = self.get_model(predictor_config)
         self.encoder = self.get_encoder(bitstream_config)
-        self.epochs = slate.consume(training_config, 'epochs')
-        self.batch_size = slate.consume(training_config, 'batch_size')
-        self.learning_rate = slate.consume(training_config, 'learning_rate')
-        self.use_delta_encoding = slate.consume(preprocessing_config, 'use_delta_encoding')
-        self.eval_freq = slate.consume(training_config, 'eval_freq')
-        self.save_path = slate.consume(training_config, 'save_path', 'models')
 
     def get_model(self, config):
-        model_type = self.slate.consume(config, 'type')
+        model_type = slate.consume(config, 'type')
         if model_type == 'lstm':
             return LSTMPredictor(
-                input_size=self.slate.consume(config, 'input_size'),
-                hidden_size=self.slate.consume(config, 'hidden_size'),
-                num_layers=self.slate.consume(config, 'num_layers')
+                input_size=slate.consume(config, 'input_size'),
+                hidden_size=slate.consume(config, 'hidden_size'),
+                num_layers=slate.consume(config, 'num_layers')
             )
         elif model_type == 'fixed_input_nn':
             return FixedInputNNPredictor(
-                input_size=self.slate.consume(config, 'fixed_input_size'),
-                hidden_size=self.slate.consume(config, 'hidden_size')
+                input_size=slate.consume(config, 'fixed_input_size'),
+                hidden_size=slate.consume(config, 'hidden_size')
            )
         else:
             raise ValueError(f"Unknown model type: {model_type}")
 
     def get_encoder(self, config):
+        encoder_type = slate.consume(config, 'type')
+        if encoder_type == 'arithmetic':
             return ArithmeticEncoder()
+        elif encoder_type == 'identity':
+            return IdentityEncoder()
+        elif encoder_type == 'bzip2':
+            return Bzip2Encoder()
+        else:
+            raise ValueError(f"Unknown encoder type: {encoder_type}")
 
     def run(self, run, forceNoProfile=False):
-        train_model(self.model, self.train_data, self.test_data, self.epochs, self.batch_size, self.learning_rate, self.use_delta_encoding, self.encoder, self.eval_freq, self.save_path)
+        if self.slate.consume(self.config, 'profiler.enable', False) and not forceNoProfile:
+            print('{PROFILER RUNNING}')
+            with PyCallGraph(output=GraphvizOutput(output_file=f'./profiler/{self.name}.png')):
+                self.run(run, forceNoProfile=True)
+            print('{PROFILER DONE}')
+            return
+
+        train_model(
+            self.model, self.train_data, self.test_data,
+            self.epochs, self.batch_size, self.learning_rate,
+            self.use_delta_encoding, self.encoder, self.eval_freq, self.save_path
+        )
 
 if __name__ == '__main__':
     slate = Slate({'spikey': SpikeRunner})
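The new profiler branch wraps the run in pycallgraph2. A standalone sketch of that wrapper (not part of the commit; the output path is illustrative and rendering requires Graphviz to be installed):

```python
# Profiles whatever runs inside the context manager and writes a call-graph image.
from pycallgraph2 import PyCallGraph
from pycallgraph2.output import GraphvizOutput

def work():
    return sum(i * i for i in range(10_000))

with PyCallGraph(output=GraphvizOutput(output_file='profiler_example.png')):
    work()
```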
model.py (35 changed lines)

@@ -2,9 +2,9 @@ import torch
 import torch.nn as nn
 from abc import ABC, abstractmethod
 
-class BaseModel(ABC, nn.Module):
+class BaseModel(nn.Module, ABC):
     def __init__(self):
-        super().__init__()
+        super(BaseModel, self).__init__()
 
     @abstractmethod
     def forward(self, x):
@@ -23,12 +23,10 @@ class LSTMPredictor(BaseModel):
         super(LSTMPredictor, self).__init__()
         self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
         self.fc = nn.Linear(hidden_size, 1)
-        self.hidden_size = hidden_size
-        self.num_layers = num_layers
 
     def forward(self, x):
-        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
-        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
+        h0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size).to(x.device)
+        c0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size).to(x.device)
         out, _ = self.rnn(x, (h0, c0))
         out = self.fc(out)
         return out
@@ -39,8 +37,10 @@ class LSTMPredictor(BaseModel):
 
         with torch.no_grad():
             for i in range(len(data) - 1):
-                context = torch.tensor(data[max(0, i - self.hidden_size):i]).view(1, -1, 1).float()
-                prediction = self.forward(context).item()
+                context = torch.tensor(data[max(0, i - self.rnn.hidden_size):i], dtype=torch.float32).unsqueeze(0).unsqueeze(2).to(next(self.parameters()).device)
+                if context.shape[1] == 0:
+                    context = torch.zeros((1, 1, 1)).to(next(self.parameters()).device)
+                prediction = self.forward(context).cpu().numpy()[0][0]
                 delta = data[i] - prediction
                 encoded_data.append(delta)
 
@@ -52,8 +52,10 @@ class LSTMPredictor(BaseModel):
 
         with torch.no_grad():
             for i in range(len(encoded_data)):
-                context = torch.tensor(decoded_data[max(0, i - self.hidden_size):i]).view(1, -1, 1).float()
-                prediction = self.forward(context).item()
+                context = torch.tensor(decoded_data[max(0, i - self.rnn.hidden_size):i], dtype=torch.float32).unsqueeze(0).unsqueeze(2).to(next(self.parameters()).device)
+                if context.shape[1] == 0:
+                    context = torch.zeros((1, 1, 1)).to(next(self.parameters()).device)
+                prediction = self.forward(context).cpu().numpy()[0][0]
                 decoded_data.append(prediction + encoded_data[i])
 
         return decoded_data
@@ -64,7 +66,6 @@ class FixedInputNNPredictor(BaseModel):
         self.fc1 = nn.Linear(input_size, hidden_size)
         self.relu = nn.ReLU()
         self.fc2 = nn.Linear(hidden_size, 1)
-        self.input_size = input_size
 
     def forward(self, x):
         x = self.fc1(x)
@@ -77,10 +78,10 @@ class FixedInputNNPredictor(BaseModel):
         encoded_data = []
 
         with torch.no_grad():
-            for i in range(len(data) - self.input_size):
-                context = torch.tensor(data[i:i + self.input_size]).view(1, -1).float()
-                prediction = self.forward(context).item()
-                delta = data[i + self.input_size] - prediction
+            for i in range(len(data) - self.fc1.in_features):
+                context = torch.tensor(data[i:i + self.fc1.in_features], dtype=torch.float32).unsqueeze(0).to(next(self.parameters()).device)
+                prediction = self.forward(context).cpu().numpy()[0][0]
+                delta = data[i + self.fc1.in_features] - prediction
                 encoded_data.append(delta)
 
         return encoded_data
@@ -91,8 +92,8 @@ class FixedInputNNPredictor(BaseModel):
 
         with torch.no_grad():
             for i in range(len(encoded_data)):
-                context = torch.tensor(decoded_data[max(0, i - self.input_size):i]).view(1, -1).float()
-                prediction = self.forward(context).item()
+                context = torch.tensor(decoded_data[max(0, i - self.fc1.in_features):i], dtype=torch.float32).unsqueeze(0).to(next(self.parameters()).device)
+                prediction = self.forward(context).cpu().numpy()[0][0]
                 decoded_data.append(prediction + encoded_data[i])
 
         return decoded_data
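The predictors now read sizes off their layers instead of caching `hidden_size`, `num_layers`, and `input_size` themselves; a quick check (not part of the diff) that those attributes exist on the layers the new code relies on:

```python
# nn.LSTM exposes hidden_size/num_layers and nn.Linear exposes in_features.
import torch.nn as nn

rnn = nn.LSTM(input_size=1, hidden_size=16, num_layers=2, batch_first=True)
fc1 = nn.Linear(10, 16)
print(rnn.hidden_size, rnn.num_layers)  # 16 2
print(fc1.in_features)                  # 10
```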
requirements.txt

@@ -4,3 +4,4 @@ scipy
 matplotlib
 wandb
 pyyaml
+arithmetic_compressor
train.py (32 changed lines)

@@ -10,17 +10,17 @@ from data_processing import delta_encode, delta_decode, save_wav
 from utils import visualize_prediction, plot_delta_distribution
 from bitstream import ArithmeticEncoder
 
-def evaluate_model(model, data, use_delta_encoding, encoder, sample_rate=19531, epoch=0, num_points=None):
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def evaluate_model(model, data, use_delta_encoding, encoder, sample_rate=19531, epoch=0):
     compression_ratios = []
     identical_count = 0
     all_deltas = []
 
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model.to(device)
+    model.eval()
 
     for file_data in data:
         file_data = torch.tensor(file_data, dtype=torch.float32).unsqueeze(1).to(device)
-        encoded_data = model(file_data).squeeze(1).cpu().detach().numpy().tolist()
+        encoded_data = model.encode(file_data.squeeze(1).cpu().numpy())
         encoder.build_model(encoded_data)
         compressed_data = encoder.encode(encoded_data)
         decompressed_data = encoder.decode(compressed_data, len(encoded_data))
@@ -36,14 +36,14 @@ def evaluate_model(model, data, use_delta_encoding, encoder, sample_rate=19531,
         compression_ratios.append(compression_ratio)
 
         # Compute and collect deltas
-        predicted_data = model(torch.tensor(encoded_data, dtype=torch.float32).unsqueeze(1).to(device)).squeeze(1).cpu().detach().numpy().tolist()
+        predicted_data = model.decode(encoded_data)
         if use_delta_encoding:
             predicted_data = delta_decode(predicted_data)
         delta_data = [file_data[i].item() - predicted_data[i] for i in range(len(file_data))]
         all_deltas.extend(delta_data)
 
         # Visualize prediction vs data vs error
-        visualize_prediction(file_data.cpu().numpy(), predicted_data, delta_data, sample_rate, num_points)
+        visualize_prediction(file_data.cpu().numpy(), predicted_data, delta_data, sample_rate)
 
     identical_percentage = (identical_count / len(data)) * 100
 
@@ -53,22 +53,24 @@ def evaluate_model(model, data, use_delta_encoding, encoder, sample_rate=19531,
 
     return compression_ratios, identical_percentage
 
-def train_model(model, train_data, test_data, epochs, batch_size, learning_rate, use_delta_encoding, encoder, eval_freq, save_path, num_points=None):
+def train_model(model, train_data, test_data, epochs, batch_size, learning_rate, use_delta_encoding, encoder, eval_freq, save_path):
     """Train the model."""
     wandb.init(project="wav-compression")
     criterion = nn.MSELoss()
     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
     best_test_score = float('inf')
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model.to(device)
+    model = model.to(device)
 
     for epoch in range(epochs):
+        model.train()
         total_loss = 0
        random.shuffle(train_data)  # Shuffle data for varied batches
         for i in range(0, len(train_data) - batch_size, batch_size):
-            inputs = torch.tensor(train_data[i:i+batch_size], dtype=torch.float32).unsqueeze(2).to(device)
-            targets = torch.tensor(train_data[i+1:i+batch_size+1], dtype=torch.float32).unsqueeze(2).to(device)
+            batch = train_data[i:i+batch_size]
+            max_len = max(len(seq) for seq in batch)
+            padded_batch = np.array([np.pad(seq, (0, max_len - len(seq))) for seq in batch], dtype=np.float32)
+            inputs = torch.tensor(padded_batch[:, :-1], dtype=torch.float32).unsqueeze(2).to(device)
+            targets = torch.tensor(padded_batch[:, 1:], dtype=torch.float32).unsqueeze(2).to(device)
             outputs = model(inputs)
             loss = criterion(outputs, targets)
             optimizer.zero_grad()
@@ -81,8 +83,8 @@ def train_model(model, train_data, test_data, epochs, batch_size, learning_rate,
 
         if (epoch + 1) % eval_freq == 0:
             # Evaluate on train and test data
-            train_compression_ratios, train_identical_percentage = evaluate_model(model, train_data, use_delta_encoding, encoder, epoch=epoch, num_points=num_points)
-            test_compression_ratios, test_identical_percentage = evaluate_model(model, test_data, use_delta_encoding, encoder, epoch=epoch, num_points=num_points)
+            train_compression_ratios, train_identical_percentage = evaluate_model(model, train_data, use_delta_encoding, encoder, epoch=epoch)
+            test_compression_ratios, test_identical_percentage = evaluate_model(model, test_data, use_delta_encoding, encoder, epoch=epoch)
 
             # Log statistics
             wandb.log({
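The reworked batching pads variable-length recordings to the longest sequence in the batch and forms next-sample targets by shifting one step; a toy illustration (not from the diff, sequences invented):

```python
# Right-pad with zeros to the longest sequence, then shift by one sample.
import numpy as np

batch = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0]]
max_len = max(len(seq) for seq in batch)
padded = np.array([np.pad(seq, (0, max_len - len(seq))) for seq in batch], dtype=np.float32)

inputs, targets = padded[:, :-1], padded[:, 1:]
print(inputs)   # [[1. 2. 3.] [5. 6. 0.]]
print(targets)  # [[2. 3. 4.] [6. 0. 0.]]
```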