From 359299b7ccb6518a7cb4d1e0259bf99dd3e8ba59 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 29 May 2024 21:11:02 +0200 Subject: [PATCH] Fix Bug: Rice dropping Sign of Deltas --- bitstream.py | 60 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/bitstream.py b/bitstream.py index 3aa209d..e1055fe 100644 --- a/bitstream.py +++ b/bitstream.py @@ -11,21 +11,21 @@ class BaseEncoder(ABC): pass @abstractmethod - def decode(self, encoded_data, num_symbols): + def decode(self, encoded_data): pass @abstractmethod - def build_model(self, data): + def build_model(self, delta_samples): pass class IdentityEncoder(BaseEncoder): def encode(self, data): return data - def decode(self, encoded_data, num_symbols): + def decode(self, encoded_data): return encoded_data - def build_model(self, data): + def build_model(self, delta_samples): pass class ArithmeticEncoder(BaseEncoder): @@ -41,9 +41,9 @@ class ArithmeticEncoder(BaseEncoder): decoded_data = coder.decompress(encoded_data, num_symbols) return decoded_data - def build_model(self, data): + def build_model(self, delta_samples): # Convert data to list of tuples - data = [tuple(d) for d in data] + data = [tuple(d) for d in delta_samples] symbol_counts = {symbol: data.count(symbol) for symbol in set(data)} total_symbols = sum(symbol_counts.values()) probabilities = {symbol: count / total_symbols for symbol, count in symbol_counts.items()} @@ -53,7 +53,7 @@ class Bzip2Encoder(BaseEncoder): def encode(self, data): return bz2.compress(bytearray(data)) - def decode(self, encoded_data, num_symbols): + def decode(self, encoded_data): return list(bz2.decompress(encoded_data)) def build_model(self, data): @@ -61,7 +61,7 @@ class Bzip2Encoder(BaseEncoder): class BinomialHuffmanEncoder(BaseEncoder): def encode(self, data): - return ''.join(self.codebook[int(value)+512] for value in data) + return ''.join(self.codebook[int(value)+1024] for value in data) def decode(self, encoded_data): decoded_output = [] @@ -73,24 +73,33 @@ class BinomialHuffmanEncoder(BaseEncoder): current_node = current_node.right if current_node.left is None and current_node.right is None: - decoded_output.append(current_node.value) + decoded_output.append(current_node.value-1024) current_node = self.root return decoded_output - def _generate_codes(self, node, prefix="", codebook={}): - if node is not None: + def _generate_codes(self, root): + if root is None: + return {} + codebook = {} + stack = [(root, "")] + while stack: + node, prefix = stack.pop() if node.value is not None: codebook[node.value] = prefix - self._generate_codes(node.left, prefix + "0", codebook) - self._generate_codes(node.right, prefix + "1", codebook) + if node.right is not None: + stack.append((node.right, prefix + "1")) + if node.left is not None: + stack.append((node.left, prefix + "0")) return codebook - def build_model(self, data): - num_symbols = 2**10 + def build_model(self, delta_samples, adaptive=True): + num_symbols = 2**11 mean = (num_symbols - 1) / 2 std_dev = math.sqrt(num_symbols / 4) + if adaptive: + std_dev = np.std(delta_samples) class Node: def __init__(self, value, freq): @@ -122,18 +131,20 @@ class BinomialHuffmanEncoder(BaseEncoder): class RiceEncoder(BaseEncoder): def encode(self, data): data = np.array(data).astype(int) - q = data // self.m - r = data % self.m - encoded_data = [] - for qi, ri in zip(q, r): - encoded_data.append('1' * qi + '0' + format(ri, f'0{self.k}b')) + + for num in data: + num = self.zigzag_encode(num) + q = num // self.m + r = num % self.m + encoded_data.append('1' * q + '0' + format(r, f'0{self.k}b')) return ''.join(encoded_data) def decode(self, encoded_data): decoded_output = [] i = 0 + while i < len(encoded_data): q = 0 while encoded_data[i] == '1': @@ -142,10 +153,17 @@ class RiceEncoder(BaseEncoder): i += 1 # skip the '0' r = int(encoded_data[i:i + self.k], 2) i += self.k - decoded_output.append(q * self.m + r) + num = q * self.m + r + decoded_output.append(self.zigzag_decode(num)) return np.array(decoded_output) def build_model(self, data, k=3): self.k = k self.m = 1 << k + + def zigzag_encode(self, value): + return (value << 1) ^ (value >> 31) + + def zigzag_decode(self, value): + return (value >> 1) ^ -(value & 1) \ No newline at end of file