Fix Bug: Rice dropping Sign of Deltas
This commit is contained in:
parent
1158817cfc
commit
359299b7cc
60
bitstream.py
60
bitstream.py
@ -11,21 +11,21 @@ class BaseEncoder(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, encoded_data, num_symbols):
|
||||
def decode(self, encoded_data):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_model(self, data):
|
||||
def build_model(self, delta_samples):
|
||||
pass
|
||||
|
||||
class IdentityEncoder(BaseEncoder):
|
||||
def encode(self, data):
|
||||
return data
|
||||
|
||||
def decode(self, encoded_data, num_symbols):
|
||||
def decode(self, encoded_data):
|
||||
return encoded_data
|
||||
|
||||
def build_model(self, data):
|
||||
def build_model(self, delta_samples):
|
||||
pass
|
||||
|
||||
class ArithmeticEncoder(BaseEncoder):
|
||||
@ -41,9 +41,9 @@ class ArithmeticEncoder(BaseEncoder):
|
||||
decoded_data = coder.decompress(encoded_data, num_symbols)
|
||||
return decoded_data
|
||||
|
||||
def build_model(self, data):
|
||||
def build_model(self, delta_samples):
|
||||
# Convert data to list of tuples
|
||||
data = [tuple(d) for d in data]
|
||||
data = [tuple(d) for d in delta_samples]
|
||||
symbol_counts = {symbol: data.count(symbol) for symbol in set(data)}
|
||||
total_symbols = sum(symbol_counts.values())
|
||||
probabilities = {symbol: count / total_symbols for symbol, count in symbol_counts.items()}
|
||||
@ -53,7 +53,7 @@ class Bzip2Encoder(BaseEncoder):
|
||||
def encode(self, data):
|
||||
return bz2.compress(bytearray(data))
|
||||
|
||||
def decode(self, encoded_data, num_symbols):
|
||||
def decode(self, encoded_data):
|
||||
return list(bz2.decompress(encoded_data))
|
||||
|
||||
def build_model(self, data):
|
||||
@ -61,7 +61,7 @@ class Bzip2Encoder(BaseEncoder):
|
||||
|
||||
class BinomialHuffmanEncoder(BaseEncoder):
|
||||
def encode(self, data):
|
||||
return ''.join(self.codebook[int(value)+512] for value in data)
|
||||
return ''.join(self.codebook[int(value)+1024] for value in data)
|
||||
|
||||
def decode(self, encoded_data):
|
||||
decoded_output = []
|
||||
@ -73,24 +73,33 @@ class BinomialHuffmanEncoder(BaseEncoder):
|
||||
current_node = current_node.right
|
||||
|
||||
if current_node.left is None and current_node.right is None:
|
||||
decoded_output.append(current_node.value)
|
||||
decoded_output.append(current_node.value-1024)
|
||||
current_node = self.root
|
||||
|
||||
return decoded_output
|
||||
|
||||
def _generate_codes(self, node, prefix="", codebook={}):
|
||||
if node is not None:
|
||||
def _generate_codes(self, root):
|
||||
if root is None:
|
||||
return {}
|
||||
codebook = {}
|
||||
stack = [(root, "")]
|
||||
while stack:
|
||||
node, prefix = stack.pop()
|
||||
if node.value is not None:
|
||||
codebook[node.value] = prefix
|
||||
self._generate_codes(node.left, prefix + "0", codebook)
|
||||
self._generate_codes(node.right, prefix + "1", codebook)
|
||||
if node.right is not None:
|
||||
stack.append((node.right, prefix + "1"))
|
||||
if node.left is not None:
|
||||
stack.append((node.left, prefix + "0"))
|
||||
return codebook
|
||||
|
||||
def build_model(self, data):
|
||||
num_symbols = 2**10
|
||||
def build_model(self, delta_samples, adaptive=True):
|
||||
num_symbols = 2**11
|
||||
|
||||
mean = (num_symbols - 1) / 2
|
||||
std_dev = math.sqrt(num_symbols / 4)
|
||||
if adaptive:
|
||||
std_dev = np.std(delta_samples)
|
||||
|
||||
class Node:
|
||||
def __init__(self, value, freq):
|
||||
@ -122,18 +131,20 @@ class BinomialHuffmanEncoder(BaseEncoder):
|
||||
class RiceEncoder(BaseEncoder):
|
||||
def encode(self, data):
|
||||
data = np.array(data).astype(int)
|
||||
q = data // self.m
|
||||
r = data % self.m
|
||||
|
||||
encoded_data = []
|
||||
for qi, ri in zip(q, r):
|
||||
encoded_data.append('1' * qi + '0' + format(ri, f'0{self.k}b'))
|
||||
|
||||
for num in data:
|
||||
num = self.zigzag_encode(num)
|
||||
q = num // self.m
|
||||
r = num % self.m
|
||||
encoded_data.append('1' * q + '0' + format(r, f'0{self.k}b'))
|
||||
|
||||
return ''.join(encoded_data)
|
||||
|
||||
def decode(self, encoded_data):
|
||||
decoded_output = []
|
||||
i = 0
|
||||
|
||||
while i < len(encoded_data):
|
||||
q = 0
|
||||
while encoded_data[i] == '1':
|
||||
@ -142,10 +153,17 @@ class RiceEncoder(BaseEncoder):
|
||||
i += 1 # skip the '0'
|
||||
r = int(encoded_data[i:i + self.k], 2)
|
||||
i += self.k
|
||||
decoded_output.append(q * self.m + r)
|
||||
num = q * self.m + r
|
||||
decoded_output.append(self.zigzag_decode(num))
|
||||
|
||||
return np.array(decoded_output)
|
||||
|
||||
def build_model(self, data, k=3):
|
||||
self.k = k
|
||||
self.m = 1 << k
|
||||
|
||||
def zigzag_encode(self, value):
|
||||
return (value << 1) ^ (value >> 31)
|
||||
|
||||
def zigzag_decode(self, value):
|
||||
return (value >> 1) ^ -(value & 1)
|
Loading…
Reference in New Issue
Block a user