blob: c8aa276ef5cfa8f6fb08c386cd75e4e8b5aed501 [file] [log] [blame]
# -*- coding: utf-8 -*-
"""
hpack/huffman_decoder
~~~~~~~~~~~~~~~~~~~~~
An implementation of a bitwise prefix tree specially built for decoding
Huffman-coded content where we already know the Huffman table, together with
the matching encoder.
"""
from .compat import to_byte, decode_hex
from .exceptions import HPACKDecodingError
def _pad_binary(bin_str, req_len=8):
"""
Given a binary string (returned by bin()), pad it to a full byte length.
"""
bin_str = bin_str[2:] # Strip the 0b prefix
return max(0, req_len - len(bin_str)) * '0' + bin_str
def _hex_to_bin_str(hex_string):
    """
    Given a Python bytestring, return its bits as a string of ``'0'`` and
    ``'1'`` characters: eight characters per input byte, most significant bit
    of each byte first.

    The parameter name is historical; the input is raw bytes, not hex text.
    """
    # '08b' renders each byte as exactly eight binary digits, keeping the
    # leading zeros that bin() would drop.
    return ''.join(format(to_byte(c), '08b') for c in hex_string)
class HuffmanDecoder(object):
    """
    Decodes a Huffman-coded bytestream according to the Huffman table laid out
    in the HPACK specification.

    The table is held as a binary trie: each node maps the next bit
    (``'0'`` or ``'1'``) to a child node, and a node at which a code ends
    stores the index of the symbol that code represents.
    """
    # Index of the end-of-string (EOS) symbol in the code list.
    _EOS_SYMBOL = 256

    # Longest padding allowed after the final symbol (RFC 7541, Section 5.2).
    _MAX_PADDING_BITS = 7

    class _Node(object):
        """A single node of the decoding trie."""
        __slots__ = ('data', 'mapping')

        def __init__(self, data):
            # Symbol index if a code ends at this node, otherwise None.
            self.data = data
            # Maps a bit character ('0' or '1') to the child node.
            self.mapping = {}

    def __init__(self, huffman_code_list, huffman_code_list_lengths):
        """
        Build the decoding trie.

        :param huffman_code_list: Integer code for each symbol, indexed by
            symbol value.
        :param huffman_code_list_lengths: Bit length of each code, in the same
            order as ``huffman_code_list``.
        """
        self.root = self._Node(None)
        codes = list(zip(huffman_code_list, huffman_code_list_lengths))
        for index, (huffman_code, code_length) in enumerate(codes):
            self._insert(huffman_code, code_length, index)

        # Keep the EOS code's bits so padding can be checked against its most
        # significant bits. Without an EOS entry, fall back to all-ones
        # padding (what HPACK's EOS prefix always is).
        if len(codes) > self._EOS_SYMBOL:
            eos_code, eos_length = codes[self._EOS_SYMBOL]
            self._eos_bits = _pad_binary(bin(eos_code), eos_length)
        else:
            self._eos_bits = None

    def _insert(self, hex_number, hex_length, letter):
        """
        Inserts a Huffman code point into the tree.

        :param hex_number: The integer value of the code.
        :param hex_length: The code's length in bits; used to restore the
            leading zero bits that the integer value does not carry.
        :param letter: The symbol index stored at the node where the code ends.
        """
        bits = _pad_binary(bin(hex_number), hex_length)
        cur_node = self.root
        for digit in bits:
            if digit not in cur_node.mapping:
                cur_node.mapping[digit] = self._Node(None)
            cur_node = cur_node.mapping[digit]
        cur_node.data = letter

    def decode(self, encoded_string):
        """
        Decode the given Huffman coded string.

        :param encoded_string: The Huffman-coded bytes.
        :returns: The decoded bytes.
        :raises HPACKDecodingError: if the bits do not follow a path in the
            trie, if the EOS symbol is encoded in the string, or if the
            trailing padding is longer than 7 bits or is not a prefix of the
            EOS code (RFC 7541, Section 5.2).
        """
        bits = _hex_to_bin_str(encoded_string)
        cur_node = self.root
        decoded_message = bytearray()
        # Position just past the last complete symbol; every bit after it is
        # padding.
        symbol_end = 0
        for position, digit in enumerate(bits, 1):
            cur_node = cur_node.mapping.get(digit)
            if cur_node is None:
                # We have a Huffman-coded string that doesn't match our trie.
                # This is pretty bad: raise a useful exception.
                raise HPACKDecodingError("Invalid Huffman-coded string received.")
            if cur_node.data is None:
                continue
            # EOS must never appear in an encoded string; it is a decoding
            # error, not a terminator.
            if cur_node.data == self._EOS_SYMBOL:
                raise HPACKDecodingError("Huffman-coded string contains EOS.")
            decoded_message.append(cur_node.data)
            cur_node = self.root
            symbol_end = position

        self._check_padding(bits[symbol_end:])
        return bytes(decoded_message)

    def _check_padding(self, padding):
        """
        Validate the bits left over after the last decoded symbol.

        :param padding: The trailing bit characters not consumed by a symbol.
        :raises HPACKDecodingError: if there are more than 7 padding bits, or
            they are not the most significant bits of the EOS code.
        """
        if len(padding) > self._MAX_PADDING_BITS:
            raise HPACKDecodingError("Huffman-coded string has invalid padding.")
        if self._eos_bits is not None:
            expected = self._eos_bits[:len(padding)]
        else:
            expected = '1' * len(padding)
        if padding != expected:
            raise HPACKDecodingError("Huffman-coded string has invalid padding.")
class HuffmanEncoder(object):
    """
    Encodes a string according to the Huffman encoding table defined in the
    HPACK specification.
    """
    def __init__(self, huffman_code_list, huffman_code_list_lengths):
        """
        :param huffman_code_list: Integer code for each byte value, indexed by
            that byte value.
        :param huffman_code_list_lengths: Bit length of each code, in the same
            order as ``huffman_code_list``.
        """
        self.huffman_code_list = huffman_code_list
        self.huffman_code_list_lengths = huffman_code_list_lengths

    def encode(self, bytes_to_encode):
        """
        Given a string of bytes, encodes them according to the HPACK Huffman
        specification.

        :param bytes_to_encode: The bytes to encode.
        :returns: The Huffman-coded bytes, padded to a whole number of octets
            with 1 bits. An empty input yields ``b''``.
        """
        # If handed the empty string, just immediately return.
        if not bytes_to_encode:
            return b''

        final_num = 0
        final_int_len = 0

        # Turn each byte into its huffman code. These codes aren't necessarily
        # octet aligned, so keep track of how far through an octet we are. To
        # handle this cleanly, just use a single giant integer.
        for char in bytes_to_encode:
            byte = to_byte(char)
            bin_int_len = self.huffman_code_list_lengths[byte]
            # Keep exactly bin_int_len low bits of the code.
            bin_int = self.huffman_code_list[byte] & ((1 << bin_int_len) - 1)
            final_num <<= bin_int_len
            final_num |= bin_int
            final_int_len += bin_int_len

        # Pad out to an octet with ones.
        bits_to_be_padded = (8 - (final_int_len % 8)) % 8
        final_num <<= bits_to_be_padded
        final_num |= (1 << bits_to_be_padded) - 1

        # Render exactly two hex digits per output byte. The zero-filled width
        # restores any leading zero bytes, and '%x' never emits a '0x' prefix
        # or a Python 2 'L' suffix.
        total_bytes = (final_int_len + bits_to_be_padded) // 8
        return decode_hex('%0*x' % (total_bytes * 2, final_num))