3
This commit is contained in:
parent
e0f51b5ee0
commit
82f7b26d53
16
config.yaml
16
config.yaml
@ -47,15 +47,15 @@ preprocessing:
|
||||
predictor:
|
||||
type: lstm # Options: 'lstm', 'fixed_input_nn'
|
||||
input_size: 1 # Input size for the LSTM predictor.
|
||||
hidden_size: 16 # Hidden size for the LSTM or Fixed Input NN predictor.
|
||||
hidden_size: 8 # 16 # Hidden size for the LSTM or Fixed Input NN predictor.
|
||||
num_layers: 2 # Number of layers for the LSTM predictor.
|
||||
fixed_input_size: 10 # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'.
|
||||
|
||||
training:
|
||||
epochs: 10 # Number of training epochs.
|
||||
epochs: 128 # Number of training epochs.
|
||||
batch_size: 8 # Batch size for training.
|
||||
learning_rate: 0.001 # Learning rate for the optimizer.
|
||||
eval_freq: 2 # Frequency of evaluation during training (in epochs).
|
||||
eval_freq: 8 # Frequency of evaluation during training (in epochs).
|
||||
save_path: models # Directory to save the best model and encoder.
|
||||
num_points: 1000 # Number of data points to visualize
|
||||
|
||||
@ -68,4 +68,12 @@ data:
|
||||
split_ratio: 0.8 # Ratio to split the data into train and test sets.
|
||||
|
||||
profiler:
|
||||
enable: false
|
||||
enable: false
|
||||
|
||||
ablative:
|
||||
training:
|
||||
learning_rate: [0.01, 0.0001, 0.00001]
|
||||
batch_size: [4, 16]
|
||||
predictor:
|
||||
hidden_size: [4, 16]
|
||||
num_layers: [1, 3]
|
||||
|
@ -35,13 +35,12 @@ def delta_encode(data):
|
||||
"""Apply delta encoding to the data."""
|
||||
deltas = [data[0]]
|
||||
for i in range(1, len(data)):
|
||||
delta = np.subtract(data[i], data[i - 1], dtype=np.float32) # Using numpy subtract to handle overflow
|
||||
deltas.append(delta)
|
||||
return deltas
|
||||
deltas.append(data[i] - data[i - 1])
|
||||
return np.array(deltas)
|
||||
|
||||
def delta_decode(deltas):
|
||||
"""Decode delta encoded data."""
|
||||
data = [deltas[0]]
|
||||
for i in range(1, len(deltas)):
|
||||
data.append(data[-1] + deltas[i])
|
||||
return data
|
||||
return np.array(data)
|
45
model.py
45
model.py
@ -2,7 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class BaseModel(nn.Module, ABC):
|
||||
class BaseModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(BaseModel, self).__init__()
|
||||
|
||||
@ -23,10 +23,12 @@ class LSTMPredictor(BaseModel):
|
||||
super(LSTMPredictor, self).__init__()
|
||||
self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
|
||||
self.fc = nn.Linear(hidden_size, 1)
|
||||
self.hidden_size = hidden_size
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def forward(self, x):
|
||||
h0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size).to(x.device)
|
||||
c0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size).to(x.device)
|
||||
h0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size).to(self.device)
|
||||
c0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size).to(self.device)
|
||||
out, _ = self.rnn(x, (h0, c0))
|
||||
out = self.fc(out)
|
||||
return out
|
||||
@ -35,12 +37,13 @@ class LSTMPredictor(BaseModel):
|
||||
self.eval()
|
||||
encoded_data = []
|
||||
|
||||
context_size = self.hidden_size # Define an appropriate context size
|
||||
with torch.no_grad():
|
||||
for i in range(len(data) - 1):
|
||||
context = torch.tensor(data[max(0, i - self.rnn.hidden_size):i], dtype=torch.float32).unsqueeze(0).unsqueeze(2).to(next(self.parameters()).device)
|
||||
if context.shape[1] == 0:
|
||||
context = torch.zeros((1, 1, 1)).to(next(self.parameters()).device)
|
||||
prediction = self.forward(context).cpu().numpy()[0][0]
|
||||
context = torch.tensor(data[max(0, i - context_size):i]).reshape(1, -1, 1).to(self.device)
|
||||
if context.size(1) == 0: # Handle empty context
|
||||
continue
|
||||
prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
|
||||
delta = data[i] - prediction
|
||||
encoded_data.append(delta)
|
||||
|
||||
@ -50,12 +53,13 @@ class LSTMPredictor(BaseModel):
|
||||
self.eval()
|
||||
decoded_data = []
|
||||
|
||||
context_size = self.hidden_size # Define an appropriate context size
|
||||
with torch.no_grad():
|
||||
for i in range(len(encoded_data)):
|
||||
context = torch.tensor(decoded_data[max(0, i - self.rnn.hidden_size):i], dtype=torch.float32).unsqueeze(0).unsqueeze(2).to(next(self.parameters()).device)
|
||||
if context.shape[1] == 0:
|
||||
context = torch.zeros((1, 1, 1)).to(next(self.parameters()).device)
|
||||
prediction = self.forward(context).cpu().numpy()[0][0]
|
||||
context = torch.tensor(decoded_data[max(0, i - context_size):i]).reshape(1, -1, 1).to(self.device)
|
||||
if context.size(1) == 0: # Handle empty context
|
||||
continue
|
||||
prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
|
||||
decoded_data.append(prediction + encoded_data[i])
|
||||
|
||||
return decoded_data
|
||||
@ -66,6 +70,7 @@ class FixedInputNNPredictor(BaseModel):
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(hidden_size, 1)
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
@ -77,11 +82,14 @@ class FixedInputNNPredictor(BaseModel):
|
||||
self.eval()
|
||||
encoded_data = []
|
||||
|
||||
context_size = self.fc1.in_features # Define an appropriate context size
|
||||
with torch.no_grad():
|
||||
for i in range(len(data) - self.fc1.in_features):
|
||||
context = torch.tensor(data[i:i + self.fc1.in_features], dtype=torch.float32).unsqueeze(0).to(next(self.parameters()).device)
|
||||
prediction = self.forward(context).cpu().numpy()[0][0]
|
||||
delta = data[i + self.fc1.in_features] - prediction
|
||||
for i in range(len(data) - context_size):
|
||||
context = torch.tensor(data[i:i + context_size]).reshape(1, -1).to(self.device)
|
||||
if context.size(1) == 0: # Handle empty context
|
||||
continue
|
||||
prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
|
||||
delta = data[i + context_size] - prediction
|
||||
encoded_data.append(delta)
|
||||
|
||||
return encoded_data
|
||||
@ -90,10 +98,13 @@ class FixedInputNNPredictor(BaseModel):
|
||||
self.eval()
|
||||
decoded_data = []
|
||||
|
||||
context_size = self.fc1.in_features # Define an appropriate context size
|
||||
with torch.no_grad():
|
||||
for i in range(len(encoded_data)):
|
||||
context = torch.tensor(decoded_data[max(0, i - self.fc1.in_features):i], dtype=torch.float32).unsqueeze(0).to(next(self.parameters()).device)
|
||||
prediction = self.forward(context).cpu().numpy()[0][0]
|
||||
context = torch.tensor(decoded_data[max(0, i - context_size):i]).reshape(1, -1).to(self.device)
|
||||
if context.size(1) == 0: # Handle empty context
|
||||
continue
|
||||
prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
|
||||
decoded_data.append(prediction + encoded_data[i])
|
||||
|
||||
return decoded_data
|
||||
|
@ -5,3 +5,6 @@ matplotlib
|
||||
wandb
|
||||
pyyaml
|
||||
arithmetic_compressor
|
||||
pycallgraph2
|
||||
setuptools
|
||||
wheel
|
||||
|
66
train.py
66
train.py
@ -10,83 +10,90 @@ from data_processing import delta_encode, delta_decode, save_wav
|
||||
from utils import visualize_prediction, plot_delta_distribution
|
||||
from bitstream import ArithmeticEncoder
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
def pad_sequence(sequence, max_length):
|
||||
padded_seq = np.zeros((max_length, *sequence.shape[1:]))
|
||||
padded_seq[:sequence.shape[0], ...] = sequence
|
||||
return padded_seq
|
||||
|
||||
def evaluate_model(model, data, use_delta_encoding, encoder, sample_rate=19531, epoch=0):
|
||||
compression_ratios = []
|
||||
identical_count = 0
|
||||
all_deltas = []
|
||||
|
||||
model.eval()
|
||||
for file_data in data:
|
||||
file_data = torch.tensor(file_data, dtype=torch.float32).unsqueeze(1).to(device)
|
||||
for i, file_data in enumerate(data):
|
||||
file_data = torch.tensor(file_data, dtype=torch.float32).unsqueeze(1).to(model.device)
|
||||
encoded_data = model.encode(file_data.squeeze(1).cpu().numpy())
|
||||
encoder.build_model(encoded_data)
|
||||
compressed_data = encoder.encode(encoded_data)
|
||||
decompressed_data = encoder.decode(compressed_data, len(encoded_data))
|
||||
|
||||
# Check equivalence
|
||||
if use_delta_encoding:
|
||||
decompressed_data = delta_decode(decompressed_data)
|
||||
|
||||
# Ensure the lengths match
|
||||
min_length = min(len(file_data), len(decompressed_data))
|
||||
file_data = file_data[:min_length]
|
||||
decompressed_data = decompressed_data[:min_length]
|
||||
|
||||
identical = np.allclose(file_data.cpu().numpy(), decompressed_data, atol=1e-5)
|
||||
if identical:
|
||||
identical_count += 1
|
||||
|
||||
compression_ratio = len(file_data) / len(compressed_data)
|
||||
compression_ratios.append(compression_ratio)
|
||||
|
||||
# Compute and collect deltas
|
||||
predicted_data = model.decode(encoded_data)
|
||||
|
||||
predicted_data = model(torch.tensor(encoded_data, dtype=torch.float32).unsqueeze(1).to(model.device)).squeeze(1).detach().cpu().numpy()
|
||||
if use_delta_encoding:
|
||||
predicted_data = delta_decode(predicted_data)
|
||||
delta_data = [file_data[i].item() - predicted_data[i] for i in range(len(file_data))]
|
||||
|
||||
# Ensure predicted_data is a flat list of floats
|
||||
predicted_data = predicted_data[:min_length]
|
||||
|
||||
delta_data = [file_data[i].item() - predicted_data[i] for i in range(min_length)]
|
||||
all_deltas.extend(delta_data)
|
||||
|
||||
# Visualize prediction vs data vs error
|
||||
visualize_prediction(file_data.cpu().numpy(), predicted_data, delta_data, sample_rate)
|
||||
if i == (epoch % len(data)):
|
||||
visualize_prediction(file_data.cpu().numpy(), predicted_data, delta_data, sample_rate, epoch=epoch)
|
||||
|
||||
identical_percentage = (identical_count / len(data)) * 100
|
||||
|
||||
# Plot delta distribution
|
||||
delta_plot_path = plot_delta_distribution(all_deltas, epoch)
|
||||
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)})
|
||||
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
|
||||
|
||||
return compression_ratios, identical_percentage
|
||||
|
||||
def train_model(model, train_data, test_data, epochs, batch_size, learning_rate, use_delta_encoding, encoder, eval_freq, save_path):
|
||||
"""Train the model."""
|
||||
wandb.init(project="wav-compression")
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
||||
best_test_score = float('inf')
|
||||
model = model.to(device)
|
||||
|
||||
|
||||
model.to(model.device)
|
||||
|
||||
max_length = max([len(seq) for seq in train_data])
|
||||
print(f"Max sequence length: {max_length}")
|
||||
|
||||
for epoch in range(epochs):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
random.shuffle(train_data) # Shuffle data for varied batches
|
||||
random.shuffle(train_data)
|
||||
for i in range(0, len(train_data) - batch_size, batch_size):
|
||||
batch = train_data[i:i+batch_size]
|
||||
max_len = max(len(seq) for seq in batch)
|
||||
padded_batch = np.array([np.pad(seq, (0, max_len - len(seq))) for seq in batch], dtype=np.float32)
|
||||
inputs = torch.tensor(padded_batch[:, :-1], dtype=torch.float32).unsqueeze(2).to(device)
|
||||
targets = torch.tensor(padded_batch[:, 1:], dtype=torch.float32).unsqueeze(2).to(device)
|
||||
batch_data = [pad_sequence(np.array(train_data[j]), max_length) for j in range(i, i+batch_size)]
|
||||
batch_data = np.array(batch_data)
|
||||
inputs = torch.tensor(batch_data, dtype=torch.float32).unsqueeze(2).to(model.device)
|
||||
targets = torch.tensor(batch_data, dtype=torch.float32).unsqueeze(2).to(model.device)
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, targets)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.item()
|
||||
|
||||
wandb.log({"epoch": epoch, "loss": total_loss})
|
||||
|
||||
wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
|
||||
print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss}')
|
||||
|
||||
if (epoch + 1) % eval_freq == 0:
|
||||
# Evaluate on train and test data
|
||||
train_compression_ratios, train_identical_percentage = evaluate_model(model, train_data, use_delta_encoding, encoder, epoch=epoch)
|
||||
test_compression_ratios, test_identical_percentage = evaluate_model(model, test_data, use_delta_encoding, encoder, epoch=epoch)
|
||||
|
||||
# Log statistics
|
||||
wandb.log({
|
||||
"train_compression_ratio_mean": np.mean(train_compression_ratios),
|
||||
"train_compression_ratio_std": np.std(train_compression_ratios),
|
||||
@ -98,12 +105,11 @@ def train_model(model, train_data, test_data, epochs, batch_size, learning_rate,
|
||||
"test_compression_ratio_max": np.max(test_compression_ratios),
|
||||
"train_identical_percentage": train_identical_percentage,
|
||||
"test_identical_percentage": test_identical_percentage,
|
||||
})
|
||||
}, step=epoch)
|
||||
|
||||
print(f'Epoch {epoch+1}/{epochs}, Train Compression Ratio: Mean={np.mean(train_compression_ratios)}, Std={np.std(train_compression_ratios)}, Min={np.min(train_compression_ratios)}, Max={np.max(train_compression_ratios)}, Identical={train_identical_percentage}%')
|
||||
print(f'Epoch {epoch+1}/{epochs}, Test Compression Ratio: Mean={np.mean(test_compression_ratios)}, Std={np.std(test_compression_ratios)}, Min={np.min(test_compression_ratios)}, Max={np.max(test_compression_ratios)}, Identical={test_identical_percentage}%')
|
||||
|
||||
# Save model and encoder if new highscore on test data
|
||||
test_score = np.mean(test_compression_ratios)
|
||||
if test_score < best_test_score:
|
||||
best_test_score = test_score
|
||||
|
4
utils.py
4
utils.py
@ -14,7 +14,7 @@ def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None):
|
||||
plt.ylabel('Amplitude')
|
||||
plt.show()
|
||||
|
||||
def visualize_prediction(true_data, predicted_data, delta_data, sample_rate, num_points=None):
|
||||
def visualize_prediction(true_data, predicted_data, delta_data, sample_rate, num_points=None, epoch=None):
|
||||
"""Visualize the true data, predicted data, and deltas."""
|
||||
if num_points:
|
||||
true_data = true_data[:num_points]
|
||||
@ -46,7 +46,7 @@ def visualize_prediction(true_data, predicted_data, delta_data, sample_rate, num
|
||||
file_path = os.path.join(tmp_dir, f'prediction_plot_{np.random.randint(1e6)}.png')
|
||||
plt.savefig(file_path)
|
||||
plt.close()
|
||||
wandb.log({"Prediction vs True Data": wandb.Image(file_path)})
|
||||
wandb.log({"Prediction vs True Data": wandb.Image(file_path)}, step=epoch)
|
||||
|
||||
def plot_delta_distribution(deltas, epoch):
|
||||
"""Plot the distribution of deltas."""
|
||||
|
Loading…
Reference in New Issue
Block a user