2024-05-26 23:54:31 +02:00
|
|
|
import bz2, math
|
|
|
|
import heapq
|
2024-05-24 22:01:59 +02:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from arithmetic_compressor import AECompressor
|
|
|
|
from arithmetic_compressor.models import StaticModel
|
2024-05-28 12:52:21 +02:00
|
|
|
import numpy as np
|
2024-05-24 22:01:59 +02:00
|
|
|
|
|
|
|
class BaseEncoder(ABC):
    """Abstract interface shared by the encoders in this module.

    A concrete encoder must override ``encode``, ``decode`` and
    ``build_model``; until all three are provided the class cannot be
    instantiated (``abc`` raises ``TypeError``).
    """

    @abstractmethod
    def encode(self, data):
        """Return an encoded representation of *data*."""

    @abstractmethod
    def decode(self, encoded_data):
        """Turn the output of ``encode`` back into data."""

    @abstractmethod
    def build_model(self, delta_samples):
        """Prepare whatever model state the encoder needs from *delta_samples*."""
|
|
|
|
|
2024-05-24 23:02:24 +02:00
|
|
|
class IdentityEncoder(BaseEncoder):
    """No-op encoder: data passes through ``encode`` and ``decode`` unchanged.

    It keeps no model, so ``build_model`` does nothing.
    """

    def build_model(self, delta_samples):
        """Ignore *delta_samples*; there is nothing to fit."""
        return None

    def encode(self, data):
        """Return *data* itself (the same object, not a copy)."""
        return data

    def decode(self, encoded_data):
        """Return *encoded_data* itself (the same object, not a copy)."""
        return encoded_data
|
|
|
|
|
2024-05-24 22:01:59 +02:00
|
|
|
class ArithmeticEncoder(BaseEncoder):
    """Arithmetic coder built on ``arithmetic_compressor``.

    ``build_model`` estimates a ``StaticModel`` of symbol probabilities from
    sample data; ``encode``/``decode`` create an ``AECompressor`` over that
    model for each call.
    """

    def _require_model(self, action):
        """Raise ValueError unless ``build_model`` has set ``self.model``."""
        if not hasattr(self, 'model'):
            raise ValueError(f"Model not built. Call build_model(data) before {action}.")

    def encode(self, data):
        """Compress *data* with the current static model.

        NOTE(review): symbols in *data* presumably must be keys of the model
        built by ``build_model`` (tuples) — confirm against the library.

        Raises:
            ValueError: if ``build_model`` has not been called.
        """
        self._require_model("encoding")
        coder = AECompressor(self.model)
        return coder.compress(data)

    def decode(self, encoded_data, num_symbols):
        """Decompress *encoded_data* with the current static model.

        *num_symbols* is passed straight to ``AECompressor.decompress`` as the
        number of symbols to recover.

        Raises:
            ValueError: if ``build_model`` has not been called.
        """
        self._require_model("decoding")
        coder = AECompressor(self.model)
        return coder.decompress(encoded_data, num_symbols)

    def build_model(self, delta_samples):
        """Fit a ``StaticModel`` from the empirical symbol frequencies.

        Each sample is converted with ``tuple()``, so samples must be iterable
        (presumably rows of a 2-D array — confirm against callers).

        Raises:
            ValueError: if *delta_samples* is empty (no symbols to model).
        """
        from collections import Counter

        data = [tuple(d) for d in delta_samples]
        if not data:
            raise ValueError("Cannot build a model from an empty sample set.")
        # One O(n) pass instead of calling list.count() once per unique symbol.
        counts = Counter(data)
        total_symbols = len(data)
        # Iterate set(data) so the model's symbol order matches the original.
        probabilities = {symbol: counts[symbol] / total_symbols for symbol in set(data)}
        self.model = StaticModel(probabilities)
|
2024-05-24 23:02:24 +02:00
|
|
|
|
|
|
|
class Bzip2Encoder(BaseEncoder):
    """Lossless encoder that bzip2-compresses a sequence of byte values.

    ``encode`` expects integers in ``range(0, 256)``; ``decode`` returns them
    as a list of ints. No model is needed, so ``build_model`` does nothing.
    """

    def encode(self, data):
        """Return bz2-compressed bytes for the integer values in *data*.

        Raises:
            ValueError: if a value is outside ``range(0, 256)``.
            TypeError: if a value is not an integer (e.g. a float).
        """
        # Iterate explicitly: bytearray(ndarray) would copy the array's raw
        # memory (e.g. 8 bytes per int64 element) instead of its values, and
        # decode() would then not return the original data.
        return bz2.compress(bytearray(iter(data)))

    def decode(self, encoded_data):
        """Decompress *encoded_data* and return the byte values as a list of ints."""
        return list(bz2.decompress(encoded_data))

    def build_model(self, data):
        """No-op: bzip2 needs no model."""
        pass
|
2024-05-26 23:54:31 +02:00
|
|
|
|
|
|
|
class BinomialHuffmanEncoder(BaseEncoder):
    """Huffman coder whose code table comes from a Gaussian-shaped weighting.

    ``build_model`` weights each of ``NUM_SYMBOLS`` alphabet indices with a
    normal density centred on the middle of the alphabet. It then builds a
    Huffman tree from those weights and derives a prefix-code table.

    ``encode`` maps each value ``v`` to index ``int(v) + 1024``, so values in
    ``[-1024, 1023]`` are encodable. It concatenates the codes into a
    ``'0'``/``'1'`` string. ``decode`` walks the tree to invert this.
    """

    # Alphabet size, and the shift that maps signed values onto indices.
    NUM_SYMBOLS = 2 ** 11
    _OFFSET = NUM_SYMBOLS // 2

    class _Node:
        """Huffman tree node; leaves hold a symbol index, internal nodes hold None."""

        __slots__ = ("value", "freq", "left", "right")

        def __init__(self, value, freq):
            self.value = value
            self.freq = freq
            self.left = None
            self.right = None

        def __lt__(self, other):
            # heapq only needs "<"; nodes are ordered by weight.
            return self.freq < other.freq

    def _require_model(self, action):
        """Raise ValueError unless ``build_model`` has set ``root`` and ``codebook``."""
        if getattr(self, 'root', None) is None or getattr(self, 'codebook', None) is None:
            raise ValueError(f"Model not built. Call build_model(delta_samples) before {action}.")

    def encode(self, data):
        """Return the concatenated Huffman codes of *data* as a '0'/'1' string.

        Raises:
            ValueError: if ``build_model`` has not been called.
            KeyError: if ``int(value)`` falls outside ``[-1024, 1023]``.
        """
        self._require_model("encoding")
        codebook = self.codebook
        offset = self._OFFSET
        return ''.join(codebook[int(value) + offset] for value in data)

    def decode(self, encoded_data):
        """Decode a '0'/'1' string produced by ``encode`` back into ints.

        Any character other than '0' is treated as a 1-bit. Bits left over
        after the last complete code word are ignored.

        Raises:
            ValueError: if ``build_model`` has not been called.
        """
        self._require_model("decoding")
        root = self.root
        offset = self._OFFSET
        decoded_output = []
        current_node = root
        for bit in encoded_data:
            current_node = current_node.left if bit == '0' else current_node.right
            if current_node.left is None and current_node.right is None:
                decoded_output.append(current_node.value - offset)
                current_node = root
        return decoded_output

    def _generate_codes(self, root):
        """Return ``{symbol_index: code_string}`` for every leaf under *root*.

        Uses an explicit stack (no recursion); left edges append "0" and right
        edges append "1".
        """
        if root is None:
            return {}
        codebook = {}
        stack = [(root, "")]
        while stack:
            node, prefix = stack.pop()
            if node.value is not None:
                codebook[node.value] = prefix
            if node.right is not None:
                stack.append((node.right, prefix + "1"))
            if node.left is not None:
                stack.append((node.left, prefix + "0"))
        return codebook

    def build_model(self, delta_samples, adaptive=True):
        """Build the Huffman tree (``self.root``) and code table (``self.codebook``).

        Args:
            delta_samples: sample values; only used for their standard deviation
                when *adaptive* is true.
            adaptive: if true, use ``np.std(delta_samples)`` as the Gaussian
                spread. Otherwise use ``sqrt(NUM_SYMBOLS / 4)``. The default is
                also used when the samples are empty or have zero or
                non-finite spread.
        """
        num_symbols = self.NUM_SYMBOLS
        # Centre of the index range (1023.5), i.e. value -0.5 after the offset.
        mean = (num_symbols - 1) / 2
        # sqrt(n / 4) is the standard deviation of Binomial(n, 1/2).
        std_dev = math.sqrt(num_symbols / 4)
        if adaptive and np.size(delta_samples) > 0:
            sample_std = np.std(delta_samples)
            # A zero spread (constant samples) or non-finite spread would turn
            # every weight into nan and make the tree shape arbitrary.
            if np.isfinite(sample_std) and sample_std > 0:
                std_dev = sample_std

        # Normal-density weight per symbol index; the constant factor is hoisted.
        norm = 1 / (std_dev * math.sqrt(2 * math.pi))
        heap = [
            self._Node(x, norm * math.exp(-0.5 * ((x - mean) / std_dev) ** 2))
            for x in range(num_symbols)
        ]
        heapq.heapify(heap)

        # Standard Huffman merge: repeatedly join the two lightest subtrees.
        while len(heap) > 1:
            left = heapq.heappop(heap)
            right = heapq.heappop(heap)
            merged = self._Node(None, left.freq + right.freq)
            merged.left = left
            merged.right = right
            heapq.heappush(heap, merged)

        self.root = heapq.heappop(heap)
        self.codebook = self._generate_codes(self.root)
|
2024-05-28 12:52:21 +02:00
|
|
|
|
|
|
|
class RiceEncoder(BaseEncoder):
    """Golomb-Rice coder for signed integers that produces a '0'/'1' string.

    Each value is zigzag-mapped to a non-negative integer ``n``. ``n`` is then
    written as ``n // m`` one-bits, a terminating ``'0'``, and the ``k``-bit
    binary remainder ``n % m``, where ``m = 2**k`` is set by ``build_model``.
    """

    def _require_model(self, action):
        """Raise ValueError unless ``build_model`` has set ``k`` and ``m``."""
        if not hasattr(self, 'm') or not hasattr(self, 'k'):
            raise ValueError(f"Model not built. Call build_model(data) before {action}.")

    def encode(self, data):
        """Rice-encode *data* (cast with ``astype(int)``) into a '0'/'1' string.

        Raises:
            ValueError: if ``build_model`` has not been called.
        """
        self._require_model("encoding")
        k, m = self.k, self.m
        encoded_data = []
        # astype(int) truncates floats toward zero.
        for num in np.array(data).astype(int):
            # int() so the arithmetic and string repetition use Python ints.
            q, r = divmod(self.zigzag_encode(int(num)), m)
            # With k == 0 there are no remainder bits; format(0, '00b') would
            # still emit '0' and desynchronise the decoder.
            remainder_bits = format(r, f'0{k}b') if k else ''
            encoded_data.append('1' * q + '0' + remainder_bits)
        return ''.join(encoded_data)

    def decode(self, encoded_data):
        """Decode a '0'/'1' string produced by ``encode`` into a NumPy array.

        Raises:
            ValueError: if ``build_model`` has not been called, or if the
                stream ends in the middle of a code word.
        """
        self._require_model("decoding")
        k, m = self.k, self.m
        end = len(encoded_data)
        decoded_output = []
        i = 0
        while i < end:
            q = 0
            while i < end and encoded_data[i] == '1':
                q += 1
                i += 1
            # A code word needs its '0' terminator plus k remainder bits.
            if i + k >= end:
                raise ValueError("Truncated Rice-coded data.")
            i += 1  # skip the terminator
            r = int(encoded_data[i:i + k], 2) if k else 0
            i += k
            decoded_output.append(self.zigzag_decode(q * m + r))
        return np.array(decoded_output)

    def build_model(self, data, k=3):
        """Set the Rice parameter ``k`` (and ``m = 2**k``).

        *data* is accepted for interface compatibility and is not used.
        """
        self.k = k
        self.m = 1 << k

    def zigzag_encode(self, value):
        """Map signed to non-negative: 0, -1, 1, -2, 2, ... -> 0, 1, 2, 3, 4, ...

        The arithmetic shift by 63 gives -1 for negative values and 0 otherwise
        for anything that fits in 64 bits. The former shift by 31 was wrong
        outside the 32-bit range. This also works element-wise on NumPy arrays.
        """
        return (value << 1) ^ (value >> 63)

    def zigzag_decode(self, value):
        """Invert ``zigzag_encode``: 0, 1, 2, 3, 4, ... -> 0, -1, 1, -2, 2, ..."""
        return (value >> 1) ^ -(value & 1)
|