py3: make logdog lib python3 compatible

R=iannucci, yuanjunh

Bug: 1227763
Change-Id: Ic59a70ea7baa6d2814b7ea08d43e51cfb1985edf
Reviewed-on: https://chromium-review.googlesource.com/c/infra/luci/luci-py/+/3027858
Commit-Queue: Robbie Iannucci <iannucci@chromium.org>
Auto-Submit: Yiwei Zhang <yiwzhang@google.com>
Reviewed-by: Robbie Iannucci <iannucci@chromium.org>
NOKEYCHECK=True
GitOrigin-RevId: abc575286a0deb372048cbd260a9f148706fb639
diff --git a/stream.py b/stream.py
index 6da02d4..d526ef9 100644
--- a/stream.py
+++ b/stream.py
@@ -19,6 +19,9 @@
   from ctypes import GetLastError
 
 
+_PY2 = sys.version_info[0] == 2
+
+
 _StreamParamsBase = collections.namedtuple(
     '_StreamParamsBase', ('name', 'type', 'content_type', 'tags'))
 
@@ -27,7 +30,7 @@
 #
 # See "ProtocolFrameHeaderMagic" in:
 # <luci-go>/logdog/client/butlerlib/streamproto
-BUTLER_MAGIC = 'BTLR1\x1e'
+BUTLER_MAGIC = b'BTLR1\x1e'
 
 
 class StreamParams(_StreamParamsBase):
@@ -59,7 +62,8 @@
       raise ValueError('Invalid type (%s)' % (self.type,))
 
     if self.tags is not None:
-      if not isinstance(self.tags, collections.Mapping):
+      if not isinstance(
+          self.tags, collections.Mapping if _PY2 else collections.abc.Mapping):
         raise ValueError('Invalid tags type (%s)' % (self.tags,))
       for k, v in self.tags.items():
         streamname.validate_tag(k, v)
@@ -192,6 +196,29 @@
       return self._fd.close()
 
 
+  class _TextStream(_BasicStream):
+    """Extends _BasicStream, ensuring data written is UTF-8 text."""
+
+    def __init__(self, stream_client, params, fd):
+      super(StreamClient._TextStream, self).__init__(stream_client, params, fd)
+      self._fd = fd
+
+    def write(self, data):
+      if _PY2 and isinstance(data, str):
+        # byte string is unfortunately accepted in py2 because of
+        # undifferentiated usage of `str` and `unicode` but it should be
+        # discontinued in py3. User should switch to binary stream instead
+        # if there's a need to write bytes.
+        return self._fd.write(data)
+      elif _PY2 and isinstance(data, unicode):
+        return self._fd.write(data.encode('utf-8'))
+      elif not _PY2 and isinstance(data, str):
+        return self._fd.write(data.encode('utf-8'))
+      else:
+        raise ValueError(
+            'expect str, got %r that is type %s' % (data, type(data),))
+
+
   class _DatagramStream(_StreamBase):
     """Wraps a stream object to write length-prefixed datagrams."""
 
@@ -348,12 +375,12 @@
       are not valid.
     """
     self._register_new_stream(params.name)
-    params_json = params.to_json()
+    params_bytes = params.to_json().encode('utf-8')
 
     fobj = self._connect_raw()
     fobj.write(BUTLER_MAGIC)
-    varint.write_uvarint(fobj, len(params_json))
-    fobj.write(params_json)
+    varint.write_uvarint(fobj, len(params_bytes))
+    fobj.write(params_bytes)
     return fobj
 
   @contextlib.contextmanager
@@ -399,7 +426,7 @@
         type=StreamParams.TEXT,
         content_type=content_type,
         tags=tags)
-    return self._BasicStream(self, params, self.new_connection(params))
+    return self._TextStream(self, params, self.new_connection(params))
 
   @contextlib.contextmanager
   def binary(self, name, **kwargs):
diff --git a/streamname.py b/streamname.py
index 8e9bf33..91037b2 100644
--- a/streamname.py
+++ b/streamname.py
@@ -54,7 +54,7 @@
 
 
 def normalize_segment(seg, prefix=None):
-  """Given a string (str|unicode), mutate it into a valid segment name (str).
+  """Given a string, mutate it into a valid segment name.
 
   This operates by replacing invalid segment name characters with underscores
   (_) when encountered.
@@ -88,15 +88,11 @@
   if _SEGMENT_RE.match(seg) is None:
     raise AssertionError('Normalized segment is still invalid: %r' % seg)
 
-  # v could be of type unicode. As a valid stream name contains only ascii
-  # characters, it is safe to transcode v to ascii encoding (become str type).
-  if isinstance(seg, unicode):
-    return seg.encode('ascii')
   return seg
 
 
 def normalize(v, prefix=None):
-  """Given a string (str|unicode), mutate it into a valid stream name (str).
+  """Given a string, mutate it into a valid stream name.
 
   This operates by replacing invalid stream name characters with underscores (_)
   when encountered.
@@ -162,14 +158,12 @@
     try:
       validate_stream_name(self.prefix)
     except ValueError as e:
-      raise ValueError('Invalid prefix component [%s]: %s' % (
-          self.prefix, e.message,))
+      raise ValueError('Invalid prefix component [%s]: %s' % (self.prefix, e,))
 
     try:
       validate_stream_name(self.name)
     except ValueError as e:
-      raise ValueError('Invalid name component [%s]: %s' % (
-          self.name, e.message,))
+      raise ValueError('Invalid name component [%s]: %s' % (self.name, e,))
 
   def __str__(self):
     return '%s/+/%s' % (self.prefix, self.name)
diff --git a/tests/stream_test.py b/tests/stream_test.py
index eea9575..67c8c19 100755
--- a/tests/stream_test.py
+++ b/tests/stream_test.py
@@ -1,13 +1,15 @@
 #!/usr/bin/env vpython
+# -*- coding: utf-8 -*-
 # Copyright 2016 The LUCI Authors. All rights reserved.
 # Use of this source code is governed under the Apache License, Version 2.0
 # that can be found in the LICENSE file.
 
+from io import BufferedReader, BytesIO
+
 import json
 import os
 import sys
 import unittest
-import StringIO
 
 ROOT_DIR = os.path.dirname(os.path.abspath(os.path.join(
     __file__.decode(sys.getfilesystemencoding()),
@@ -62,7 +64,7 @@
   class _TestStreamClientConnection(object):
 
     def __init__(self):
-      self.buffer = StringIO.StringIO()
+      self.buffer = BytesIO()
       self.closed = False
 
     def _assert_not_closed(self):
@@ -78,7 +80,7 @@
       self.closed = True
 
     def interpret(self):
-      data = StringIO.StringIO(self.buffer.getvalue())
+      data = BytesIO(self.buffer.getvalue())
       magic = data.read(len(stream.BUTLER_MAGIC))
       if magic != stream.BUTLER_MAGIC:
         raise ValueError('Invalid magic value ([%s] != [%s])' % (
@@ -108,10 +110,10 @@
 
   @staticmethod
   def _split_datagrams(value):
-    sio = StringIO.StringIO(value)
-    while sio.pos < sio.len:
-      size_prefix, _ = varint.read_uvarint(sio)
-      data = sio.read(size_prefix)
+    br = BufferedReader(BytesIO(value))
+    while br.peek(1):
+      size_prefix, _ = varint.read_uvarint(br)
+      data = br.read(size_prefix)
       if len(data) != size_prefix:
         raise ValueError('Expected %d bytes, but only got %d' % (
             size_prefix, len(data)))
@@ -134,14 +136,15 @@
       self.assertEqual(
           fd.get_viewer_url(),
           'https://example.appspot.com/v/?s=test%2Ffoo%2Fbar%2F%2B%2Fmystream')
-      fd.write('text\nstream\nlines')
+      fd.write('text\nstream\nlines\n')
+      fd.write(u'šŸ˜„\nšŸ˜„\nšŸ˜„')
 
     conn = client.last_conn
     self.assertTrue(conn.closed)
 
     header, data = conn.interpret()
     self.assertEqual(header, {'name': 'mystream', 'type': 'text'})
-    self.assertEqual(data, 'text\nstream\nlines')
+    self.assertEqual(data.decode('utf-8'), u'text\nstream\nlines\nšŸ˜„\nšŸ˜„\nšŸ˜„')
 
   def testTextStreamWithParams(self):
     client = self._registry.create('test:value')
@@ -166,7 +169,7 @@
         'contentType': 'foo/bar',
          'tags': {'foo': 'bar', 'baz': 'qux'},
     })
-    self.assertEqual(data, 'text!')
+    self.assertEqual(data.decode('utf-8'), u'text!')
 
   def testBinaryStream(self):
     client = self._registry.create('test:value',
@@ -180,14 +183,14 @@
       self.assertEqual(
           fd.get_viewer_url(),
           'https://example.appspot.com/v/?s=test%2Ffoo%2Fbar%2F%2B%2Fmystream')
-      fd.write('\x60\x0d\xd0\x65')
+      fd.write(b'\x60\x0d\xd0\x65')
 
     conn = client.last_conn
     self.assertTrue(conn.closed)
 
     header, data = conn.interpret()
     self.assertEqual(header, {'name': 'mystream', 'type': 'binary'})
-    self.assertEqual(data, '\x60\x0d\xd0\x65')
+    self.assertEqual(data, b'\x60\x0d\xd0\x65')
 
   def testDatagramStream(self):
     client = self._registry.create('test:value',
@@ -201,10 +204,10 @@
       self.assertEqual(
           fd.get_viewer_url(),
           'https://example.appspot.com/v/?s=test%2Ffoo%2Fbar%2F%2B%2Fmystream')
-      fd.send('datagram0')
-      fd.send('dg1')
-      fd.send('')
-      fd.send('dg3')
+      fd.send(b'datagram0')
+      fd.send(b'dg1')
+      fd.send(b'')
+      fd.send(b'dg3')
 
     conn = client.last_conn
     self.assertTrue(conn.closed)
@@ -212,7 +215,7 @@
     header, data = conn.interpret()
     self.assertEqual(header, {'name': 'mystream', 'type': 'datagram'})
     self.assertEqual(list(self._split_datagrams(data)),
-        ['datagram0', 'dg1', '', 'dg3'])
+        [b'datagram0', b'dg1', b'', b'dg3'])
 
   def testStreamWithoutPrefixCannotGenerateUrls(self):
     client = self._registry.create('test:value',
@@ -257,7 +260,7 @@
 
     header, data = conn.interpret()
     self.assertEqual(header, {'name': 'mystream', 'type': 'text'})
-    self.assertEqual(data, 'Using a text stream.')
+    self.assertEqual(data.decode('utf-8'), u'Using a text stream.')
 
 
 if __name__ == '__main__':
diff --git a/tests/streamname_test.py b/tests/streamname_test.py
index d2800f0..45329c0 100755
--- a/tests/streamname_test.py
+++ b/tests/streamname_test.py
@@ -6,7 +6,6 @@
 import os
 import sys
 import unittest
-import StringIO
 
 ROOT_DIR = os.path.dirname(os.path.abspath(os.path.join(
     __file__.decode(sys.getfilesystemencoding()),
diff --git a/tests/varint_test.py b/tests/varint_test.py
index 527d7b0..0f1a51c 100755
--- a/tests/varint_test.py
+++ b/tests/varint_test.py
@@ -3,11 +3,13 @@
 # Use of this source code is governed under the Apache License, Version 2.0
 # that can be found in the LICENSE file.
 
+from io import BytesIO
+
+import binascii
 import itertools
 import os
 import sys
 import unittest
-import StringIO
 
 ROOT_DIR = os.path.dirname(os.path.abspath(os.path.join(
     __file__.decode(sys.getfilesystemencoding()),
@@ -28,9 +30,9 @@
         (0x81, b'\x81\x01'),
         (0x18080, b'\x80\x81\x06'),
     ):
-      sio = StringIO.StringIO()
-      count = varint.write_uvarint(sio, base)
-      act = sio.getvalue()
+      bytesIO = BytesIO()
+      count = varint.write_uvarint(bytesIO, base)
+      act = bytesIO.getvalue()
 
       self.assertEqual(act, exp,
           "Encoding for %d (%r) doesn't match expected (%r)" % (base, act, exp))
@@ -41,22 +43,21 @@
   def testVarintEncodeDecode(self):
     seed = (b'\x00', b'\x01', b'\x55', b'\x7F', b'\x80', b'\x81', b'\xff')
     for perm in itertools.permutations(seed):
-      perm = ''.join(perm).encode('hex')
+      perm = b''.join(perm)
 
       while len(perm) > 0:
-        exp = int(perm.encode('hex'), 16)
-
-        sio = StringIO.StringIO()
-        count = varint.write_uvarint(sio, exp)
-        sio.seek(0)
-        act, count = varint.read_uvarint(sio)
+        exp = int(binascii.hexlify(perm), 16)
+        bytesIO = BytesIO()
+        count = varint.write_uvarint(bytesIO, exp)
+        bytesIO.seek(0)
+        act, count = varint.read_uvarint(bytesIO)
 
         self.assertEqual(act, exp,
             "Decoded %r (%d) doesn't match expected (%d)" % (
-                sio.getvalue().encode('hex'), act, exp))
-        self.assertEqual(count, len(sio.getvalue()),
+                binascii.hexlify(bytesIO.getvalue()), act, exp))
+        self.assertEqual(count, len(bytesIO.getvalue()),
             "Decoded length (%d) doesn't match expected (%d)" % (
-                count, len(sio.getvalue())))
+                count, len(bytesIO.getvalue())))
 
         if perm == 0:
           break
diff --git a/varint.py b/varint.py
index 7bf3cca..c40bb5f 100644
--- a/varint.py
+++ b/varint.py
@@ -3,6 +3,7 @@
 # that can be found in the LICENSE file.
 
 import os
+import struct
 import sys
 
 
@@ -28,7 +29,7 @@
     if val > 0:
       byte |= 0b10000000
 
-    w.write(chr(byte))
+    w.write(struct.pack('B', byte))
     count += 1
   return count
 
@@ -55,7 +56,7 @@
     if len(byte) == 0:
       raise ValueError('UVarint was not terminated')
 
-    byte = ord(byte)
+    byte = struct.unpack('B', byte)[0]
     result |= ((byte & 0b01111111) << (7 * count))
     count += 1
     if byte & 0b10000000 == 0: