Support for new BinomialHuffman

This commit is contained in:
parent 8576f5b741
commit 102ddb8c85

main.py (46 changed lines: 28 additions, 18 deletions)
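Summary (inferred from the hunks below): besides wiring up the new
BinomialHuffman-capable bitstream encoding, this commit defers building the
encoder model from setup time to evaluation time so it can be fitted to the
predictor's actual deltas; adds an optional delta-shift mode to the predictor
(predictor.delta_shift, default True); fixes an off-by-one in the target/last
indexing of the training and evaluation windows; and tracks the naive
last-value "step" baseline in all_steps for visualization.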
@@ -47,6 +47,7 @@ class SpikeRunner(Slate_Runner):
         latent_size = slate.consume(config, 'latent_projector.latent_size')
         input_size = slate.consume(config, 'feature_extractor.input_size')
         region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
+        self.delta_shift = slate.consume(config, 'predictor.delta_shift', True)
         device = slate.consume(training_config, 'device', 'auto')
         if device == 'auto':
             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
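The new predictor.delta_shift flag (default True) consumed here is used in the
training and evaluation hunks further down: when enabled, the network output is
treated as a residual on top of the last observed sample. A minimal sketch of
that behaviour, with predictor standing in for any torch module from main.py:

# Minimal sketch of the delta-shift behaviour this flag enables
# (illustrative; mirrors the prediction hunks later in this diff).
import torch

def predict(predictor, region_latent, las, value_scale=1.0, delta_shift=True):
    prediction = predictor(region_latent) * value_scale
    if delta_shift:
        # Interpret the network output as a residual on the last sample.
        prediction = prediction + las
    return prediction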
@@ -71,10 +72,10 @@ class SpikeRunner(Slate_Runner):
         self.batch_size = slate.consume(training_config, 'batch_size')
         self.num_batches = slate.consume(training_config, 'num_batches')
         self.learning_rate = slate.consume(training_config, 'learning_rate')
-        self.eval_freq = slate.consume(training_config, 'eval_freq')
+        self.eval_freq = slate.consume(training_config, 'eval_freq', -1)
         self.save_path = slate.consume(training_config, 'save_path')
         self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0))
-        self.value_scale = slate.consume(training_config, 'value_scale')
+        self.value_scale = slate.consume(training_config, 'value_scale', 1.0)

         # Evaluation parameter
         self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
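Both eval_freq and value_scale now carry defaults (-1 and 1.0), making them
optional in the config. This relies on the third argument of slate.consume
acting as a fallback; a hedged sketch of those assumed semantics (the actual
implementation lives in the slate library and may differ):

# Assumed semantics of slate.consume(config, key, default) -- a guess for
# illustration only; the real implementation belongs to the slate library.
def consume(config, key, default=None):
    node = config
    parts = key.split('.')
    for part in parts[:-1]:
        node = node.get(part, {})
    return node.pop(parts[-1], default)   # pop so unconsumed keys can be detected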
@@ -93,8 +94,7 @@ class SpikeRunner(Slate_Runner):
             self.encoder = RiceEncoder()
         else:
             raise Exception('No such Encoder')
-
-        self.encoder.build_model(self.all_data, **slate.consume(config, 'bitstream_encoding'))
+        self.bitstream_encoder_config = slate.consume(config, 'bitstream_encoding')

         # Optimizer
         self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
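Instead of fitting the bitstream encoder once at setup on self.all_data, only
the config is stored here; the model is rebuilt during evaluation from the
predictor's actual deltas (see the build_model(delta_samples=delta, ...) call
in a later hunk). A sketch of the pattern, assuming only the build_model/encode
interface visible in this diff:

# Deferred-build pattern (sketch; assumes only the encoder interface that
# appears elsewhere in this diff).
class RunnerSketch:
    def setup(self, encoder, config):
        self.encoder = encoder
        # Stash the config now, fit the model later.
        self.bitstream_encoder_config = config.pop('bitstream_encoding', {})

    def on_evaluate(self, delta):
        # Fit the entropy model (e.g. the new BinomialHuffman) to the
        # deltas the predictor actually produces.
        self.encoder.build_model(delta_samples=delta, **self.bitstream_encoder_config)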
@@ -154,11 +154,13 @@ class SpikeRunner(Slate_Runner):
             # Stack the segments to form the batch
             stacked_segment = torch.stack([inputs] + peer_segments).to(device)
             stacked_segments.append(stacked_segment)
-            target = lead_data[i + self.input_size + 1]
+            target = lead_data[i + self.input_size]
             targets.append(target)
-            last = lead_data[i + self.input_size]
+            last = lead_data[i + self.input_size - 1]
             lasts.append(last)
+
+        las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)

         inp = torch.stack(stacked_segments) / self.value_scale
         feat = self.feat(inp)
         latents = self.projector(feat)
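The indexing fix: assuming the input window covers lead_data[i : i + input_size],
the first unseen sample sits at i + input_size and the last sample inside the
window at i + input_size - 1. The old code targeted one step further ahead and
used the unseen sample at i + input_size as "last", which would leak future
data into the delta shift. A toy check:

# Toy check of the corrected window/target/last indexing.
lead_data = list(range(10))               # stand-in for one channel
input_size, i = 4, 0
window = lead_data[i:i + input_size]      # samples 0..3 (the model's input)
target = lead_data[i + input_size]        # sample 4: first value outside the window
last = lead_data[i + input_size - 1]      # sample 3: final value inside the window
assert last == window[-1] and target not in window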
@@ -178,12 +180,14 @@ class SpikeRunner(Slate_Runner):
             region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
             prediction = self.predictor(region_latent)*self.value_scale
+
+            if self.delta_shift:
+                prediction = prediction + las

         # Calculate loss and backpropagate
         tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
-        las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
         loss = self.criterion(prediction, tar)
         err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
-        derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
+        derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
         rel = err / np.sum(tar.cpu().detach().numpy())
         total_loss += loss.item()
         derrs.append(derr/np.prod(tar.size()).item())
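las is now built once as a device tensor before the forward pass (previous
hunk) rather than re-created as a NumPy array here, so the delta shift happens
on-device and conversion to NumPy is confined to metric computation. A
self-contained sketch of that boundary:

# Sketch: on-device shift, NumPy only at the metrics boundary.
import numpy as np
import torch

def shifted_metrics(raw_prediction, lasts, targets, device='cpu'):
    las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
    tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
    prediction = raw_prediction + las              # delta shift, on-device
    err = np.sum(np.abs(prediction.detach().cpu().numpy() - tar.cpu().numpy()))
    derr = np.sum(np.abs(las.cpu().numpy() - tar.cpu().numpy()))  # last-value baseline
    return err, derr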
@@ -226,9 +230,7 @@ class SpikeRunner(Slate_Runner):
         all_true = []
         all_predicted = []
         all_deltas = []
-        compression_ratios = []
-        exact_matches = 0
-        total_sequences = 0
+        all_steps = []

         with torch.no_grad():
             min_length = min([len(seq) for seq in self.test_data])
@@ -261,11 +263,13 @@ class SpikeRunner(Slate_Runner):

                 stacked_segment = torch.stack([inputs] + peer_segments).to(device)
                 stacked_segments.append(stacked_segment)
-                target = lead_data[i + self.input_size + 1]
+                target = lead_data[i + self.input_size]
                 targets.append(target)
-                last = lead_data[i + self.input_size]
+                last = lead_data[i + self.input_size - 1]
                 lasts.append(last)
+
+            las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)

             inp = torch.stack(stacked_segments) / self.value_scale
             feat = self.feat(inp)
             latents = self.projector(feat)
@@ -276,11 +280,15 @@ class SpikeRunner(Slate_Runner):
                 region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
                 prediction = self.predictor(region_latent) * self.value_scale
+
+                if self.delta_shift:
+                    prediction = prediction + las

             tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
-            las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
             loss = self.criterion(prediction, tar)
-            err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
-            derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
+            delta = prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()
+            err = np.sum(np.abs(delta))
+            derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
+            step = las.cpu().detach().numpy() - tar.cpu().detach().numpy()
             rel = err / np.sum(tar.cpu().detach().numpy())
             total_loss += loss.item()
             derrs.append(derr / np.prod(tar.size()).item())
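Evaluation now also records step = las - tar, the signed error of a predictor
that simply repeats the last sample; collected in all_steps, it serves as a
baseline distribution to compare the model's delta against. A toy comparison:

# Toy comparison of the two residual streams collected here.
import numpy as np

las = np.array([[3.0], [5.0]])     # last observed values
tar = np.array([[4.0], [5.0]])     # true next values
pred = np.array([[3.9], [5.2]])    # model predictions after delta shift

delta = pred - tar                 # model residual  -> all_deltas
step = las - tar                   # naive baseline  -> all_steps
print(np.abs(delta).sum(), np.abs(step).sum())   # ~0.3 vs 1.0: model beats baseline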
@@ -289,9 +297,11 @@ class SpikeRunner(Slate_Runner):

             all_true.extend(tar.cpu().numpy())
             all_predicted.extend(prediction.cpu().numpy())
-            all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
+            all_deltas.extend(delta.tolist())
+            all_steps.extend(step.tolist())

             if self.full_compression:
+                self.encoder.build_model(delta_samples=delta, **self.bitstream_encoder_config)
                 raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
                 comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
                 ratio = raw_l / comp_l
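With full_compression on, the deferred encoder model is fitted to the current
deltas and the ratio compares the raw size of the true signal (16 bits per
int16 sample, after refuckify maps values back to integers) against the length
of the encoded delta stream. A sketch, assuming encoder.encode returns a bit
sequence, as the *16 on the raw side suggests:

# Compression-ratio bookkeeping (sketch; assumes encode() yields a sequence
# of bits, matching the 16-bits-per-sample raw count).
import numpy as np

def compression_ratio(encoder, true_values, all_deltas):
    raw_bits = len(np.asarray(true_values, dtype=np.int16)) * 16
    coded = encoder.encode(np.concatenate(all_deltas))
    return raw_bits / len(coded)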
@@ -308,7 +318,7 @@ class SpikeRunner(Slate_Runner):

         # Visualize predictions
         #visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
-        img = visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195)
+        img = visualize_prediction(all_true, all_predicted, all_deltas, all_steps, epoch=epoch, num_points=195)
         try:
             wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch)
         except:
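The call site now passes all_steps as a fourth positional argument, so
visualize_prediction presumably gained a parameter for the baseline steps. A
hypothetical sketch of such a helper; the real one lives elsewhere in the repo,
and only the extra all_steps argument is evidenced by this diff:

# Hypothetical helper shape; everything beyond the all_steps parameter is a
# guess for illustration, not the repo's actual implementation.
import numpy as np
import matplotlib.pyplot as plt

def visualize_prediction(all_true, all_predicted, all_deltas, all_steps,
                         epoch=0, num_points=195, name=None):
    n = num_points
    fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
    ax0.plot(np.ravel(all_true)[:n], label='true')
    ax0.plot(np.ravel(all_predicted)[:n], label='predicted')
    ax0.legend()
    ax1.plot(np.ravel(all_deltas)[:n], label='model delta')
    ax1.plot(np.ravel(all_steps)[:n], label='last-value step')
    ax1.legend()
    fig.suptitle(f'epoch {epoch}' + (f' ({name})' if name else ''))
    fig.canvas.draw()
    img = np.asarray(fig.canvas.buffer_rgba())   # array form, accepted by wandb.Image
    plt.close(fig)
    return img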