| # -*- coding: utf-8 -*- |
| """ |
| hpack/huffman_decoder |
| ~~~~~~~~~~~~~~~~~~~~~ |
| |
| An implementation of a bitwise prefix tree specially built for decoding |
| Huffman-coded content where we already know the Huffman table. |
| """ |
| from .compat import to_byte, decode_hex |
| from .exceptions import HPACKDecodingError |
| |
| def _pad_binary(bin_str, req_len=8): |
| """ |
| Given a binary string (returned by bin()), pad it to a full byte length. |
| """ |
| bin_str = bin_str[2:] # Strip the 0b prefix |
| return max(0, req_len - len(bin_str)) * '0' + bin_str |
| |
| def _hex_to_bin_str(hex_string): |
| """ |
| Given a Python bytestring, returns a string representing those bytes in |
| unicode form. |
| """ |
| unpadded_bin_string_list = (bin(to_byte(c)) for c in hex_string) |
| padded_bin_string_list = map(_pad_binary, unpadded_bin_string_list) |
| bitwise_message = "".join(padded_bin_string_list) |
| return bitwise_message |
| |
| |
| class HuffmanDecoder(object): |
| """ |
| Decodes a Huffman-coded bytestream according to the Huffman table laid out |
| in the HPACK specification. |
| """ |
| class _Node(object): |
| def __init__(self, data): |
| self.data = data |
| self.mapping = {} |
| |
| def __init__(self, huffman_code_list, huffman_code_list_lengths): |
| self.root = self._Node(None) |
| for index, (huffman_code, code_length) in enumerate(zip(huffman_code_list, huffman_code_list_lengths)): |
| self._insert(huffman_code, code_length, index) |
| |
| def _insert(self, hex_number, hex_length, letter): |
| """ |
| Inserts a Huffman code point into the tree. |
| """ |
| hex_number = _pad_binary(bin(hex_number), hex_length) |
| cur_node = self.root |
| for digit in hex_number: |
| if digit not in cur_node.mapping: |
| cur_node.mapping[digit] = self._Node(None) |
| cur_node = cur_node.mapping[digit] |
| cur_node.data = letter |
| |
| def decode(self, encoded_string): |
| """ |
| Decode the given Huffman coded string. |
| """ |
| number = _hex_to_bin_str(encoded_string) |
| cur_node = self.root |
| decoded_message = bytearray() |
| |
| try: |
| for digit in number: |
| cur_node = cur_node.mapping[digit] |
| if cur_node.data is not None: |
| # If we get EOS, everything else is padding. |
| if cur_node.data == 256: |
| break |
| |
| decoded_message.append(cur_node.data) |
| cur_node = self.root |
| except KeyError: |
| # We have a Huffman-coded string that doesn't match our trie. This |
| # is pretty bad: raise a useful exception. |
| raise HPACKDecodingError("Invalid Huffman-coded string received.") |
| return bytes(decoded_message) |
| |
| |
| class HuffmanEncoder(object): |
| """ |
| Encodes a string according to the Huffman encoding table defined in the |
| HPACK specification. |
| """ |
| def __init__(self, huffman_code_list, huffman_code_list_lengths): |
| self.huffman_code_list = huffman_code_list |
| self.huffman_code_list_lengths = huffman_code_list_lengths |
| |
| def encode(self, bytes_to_encode): |
| """ |
| Given a string of bytes, encodes them according to the HPACK Huffman |
| specification. |
| """ |
| # If handed the empty string, just immediately return. |
| if not bytes_to_encode: |
| return b'' |
| |
| final_num = 0 |
| final_int_len = 0 |
| |
| # Turn each byte into its huffman code. These codes aren't necessarily |
| # octet aligned, so keep track of how far through an octet we are. To |
| # handle this cleanly, just use a single giant integer. |
| for char in bytes_to_encode: |
| byte = to_byte(char) |
| bin_int_len = self.huffman_code_list_lengths[byte] |
| bin_int = self.huffman_code_list[byte] & (2 ** (bin_int_len + 1) - 1) |
| final_num <<= bin_int_len |
| final_num |= bin_int |
| final_int_len += bin_int_len |
| |
| # Pad out to an octet with ones. |
| bits_to_be_padded = (8 - (final_int_len % 8)) % 8 |
| final_num <<= bits_to_be_padded |
| final_num |= (1 << (bits_to_be_padded)) - 1 |
| |
| # Convert the number to hex and strip off the leading '0x' and the |
| # trailing 'L', if present. |
| final_num = hex(final_num)[2:].rstrip('L') |
| |
| # If this is odd, prepend a zero. |
| final_num = '0' + final_num if len(final_num) % 2 != 0 else final_num |
| |
| # This number should have twice as many digits as bytes. If not, we're |
| # missing some leading zeroes. Work out how many bytes we want and how |
| # many digits we have, then add the missing zero digits to the front. |
| total_bytes = (final_int_len + bits_to_be_padded) // 8 |
| expected_digits = total_bytes * 2 |
| |
| if len(final_num) != expected_digits: |
| missing_digits = expected_digits - len(final_num) |
| final_num = ('0' * missing_digits) + final_num |
| |
| return decode_hex(final_num) |