py3: make logdog lib python3 compatible
R=iannucci, yuanjunh
Bug: 1227763
Change-Id: Ic59a70ea7baa6d2814b7ea08d43e51cfb1985edf
Reviewed-on: https://chromium-review.googlesource.com/c/infra/luci/luci-py/+/3027858
Commit-Queue: Robbie Iannucci <iannucci@chromium.org>
Auto-Submit: Yiwei Zhang <yiwzhang@google.com>
Reviewed-by: Robbie Iannucci <iannucci@chromium.org>
NOKEYCHECK=True
GitOrigin-RevId: abc575286a0deb372048cbd260a9f148706fb639
diff --git a/stream.py b/stream.py
index 6da02d4..d526ef9 100644
--- a/stream.py
+++ b/stream.py
@@ -19,6 +19,9 @@
from ctypes import GetLastError
+_PY2 = sys.version_info[0] == 2
+
+
_StreamParamsBase = collections.namedtuple(
'_StreamParamsBase', ('name', 'type', 'content_type', 'tags'))
@@ -27,7 +30,7 @@
#
# See "ProtocolFrameHeaderMagic" in:
# <luci-go>/logdog/client/butlerlib/streamproto
-BUTLER_MAGIC = 'BTLR1\x1e'
+BUTLER_MAGIC = b'BTLR1\x1e'
class StreamParams(_StreamParamsBase):
@@ -59,7 +62,8 @@
raise ValueError('Invalid type (%s)' % (self.type,))
if self.tags is not None:
- if not isinstance(self.tags, collections.Mapping):
+ if not isinstance(
+ self.tags, collections.Mapping if _PY2 else collections.abc.Mapping):
raise ValueError('Invalid tags type (%s)' % (self.tags,))
for k, v in self.tags.items():
streamname.validate_tag(k, v)
@@ -192,6 +196,29 @@
return self._fd.close()
+ class _TextStream(_BasicStream):
+ """Extends _BasicStream, ensuring data written is UTF-8 text."""
+
+ def __init__(self, stream_client, params, fd):
+ super(StreamClient._TextStream, self).__init__(stream_client, params, fd)
+ self._fd = fd
+
+ def write(self, data):
+ if _PY2 and isinstance(data, str):
+ # byte string is unfortunately accepted in py2 because of
+ # undifferentiated usage of `str` and `unicode` but it should be
+ # discontinued in py3. User should switch to binary stream instead
+ # if there's a need to write bytes.
+ return self._fd.write(data)
+ elif _PY2 and isinstance(data, unicode):
+ return self._fd.write(data.encode('utf-8'))
+ elif not _PY2 and isinstance(data, str):
+ return self._fd.write(data.encode('utf-8'))
+ else:
+ raise ValueError(
+ 'expect str, got %r that is type %s' % (data, type(data),))
+
+
class _DatagramStream(_StreamBase):
"""Wraps a stream object to write length-prefixed datagrams."""
@@ -348,12 +375,12 @@
are not valid.
"""
self._register_new_stream(params.name)
- params_json = params.to_json()
+ params_bytes = params.to_json().encode('utf-8')
fobj = self._connect_raw()
fobj.write(BUTLER_MAGIC)
- varint.write_uvarint(fobj, len(params_json))
- fobj.write(params_json)
+ varint.write_uvarint(fobj, len(params_bytes))
+ fobj.write(params_bytes)
return fobj
@contextlib.contextmanager
@@ -399,7 +426,7 @@
type=StreamParams.TEXT,
content_type=content_type,
tags=tags)
- return self._BasicStream(self, params, self.new_connection(params))
+ return self._TextStream(self, params, self.new_connection(params))
@contextlib.contextmanager
def binary(self, name, **kwargs):
diff --git a/streamname.py b/streamname.py
index 8e9bf33..91037b2 100644
--- a/streamname.py
+++ b/streamname.py
@@ -54,7 +54,7 @@
def normalize_segment(seg, prefix=None):
- """Given a string (str|unicode), mutate it into a valid segment name (str).
+ """Given a string, mutate it into a valid segment name.
This operates by replacing invalid segment name characters with underscores
(_) when encountered.
@@ -88,15 +88,11 @@
if _SEGMENT_RE.match(seg) is None:
raise AssertionError('Normalized segment is still invalid: %r' % seg)
- # v could be of type unicode. As a valid stream name contains only ascii
- # characters, it is safe to transcode v to ascii encoding (become str type).
- if isinstance(seg, unicode):
- return seg.encode('ascii')
return seg
def normalize(v, prefix=None):
- """Given a string (str|unicode), mutate it into a valid stream name (str).
+ """Given a string, mutate it into a valid stream name.
This operates by replacing invalid stream name characters with underscores (_)
when encountered.
@@ -162,14 +158,12 @@
try:
validate_stream_name(self.prefix)
except ValueError as e:
- raise ValueError('Invalid prefix component [%s]: %s' % (
- self.prefix, e.message,))
+ raise ValueError('Invalid prefix component [%s]: %s' % (self.prefix, e,))
try:
validate_stream_name(self.name)
except ValueError as e:
- raise ValueError('Invalid name component [%s]: %s' % (
- self.name, e.message,))
+ raise ValueError('Invalid name component [%s]: %s' % (self.name, e,))
def __str__(self):
return '%s/+/%s' % (self.prefix, self.name)
diff --git a/tests/stream_test.py b/tests/stream_test.py
index eea9575..67c8c19 100755
--- a/tests/stream_test.py
+++ b/tests/stream_test.py
@@ -1,13 +1,15 @@
#!/usr/bin/env vpython
+# -*- coding: utf-8 -*-
# Copyright 2016 The LUCI Authors. All rights reserved.
# Use of this source code is governed under the Apache License, Version 2.0
# that can be found in the LICENSE file.
+from io import BufferedReader, BytesIO
+
import json
import os
import sys
import unittest
-import StringIO
ROOT_DIR = os.path.dirname(os.path.abspath(os.path.join(
__file__.decode(sys.getfilesystemencoding()),
@@ -62,7 +64,7 @@
class _TestStreamClientConnection(object):
def __init__(self):
- self.buffer = StringIO.StringIO()
+ self.buffer = BytesIO()
self.closed = False
def _assert_not_closed(self):
@@ -78,7 +80,7 @@
self.closed = True
def interpret(self):
- data = StringIO.StringIO(self.buffer.getvalue())
+ data = BytesIO(self.buffer.getvalue())
magic = data.read(len(stream.BUTLER_MAGIC))
if magic != stream.BUTLER_MAGIC:
raise ValueError('Invalid magic value ([%s] != [%s])' % (
@@ -108,10 +110,10 @@
@staticmethod
def _split_datagrams(value):
- sio = StringIO.StringIO(value)
- while sio.pos < sio.len:
- size_prefix, _ = varint.read_uvarint(sio)
- data = sio.read(size_prefix)
+ br = BufferedReader(BytesIO(value))
+ while br.peek(1):
+ size_prefix, _ = varint.read_uvarint(br)
+ data = br.read(size_prefix)
if len(data) != size_prefix:
raise ValueError('Expected %d bytes, but only got %d' % (
size_prefix, len(data)))
@@ -134,14 +136,15 @@
self.assertEqual(
fd.get_viewer_url(),
'https://example.appspot.com/v/?s=test%2Ffoo%2Fbar%2F%2B%2Fmystream')
- fd.write('text\nstream\nlines')
+ fd.write('text\nstream\nlines\n')
+ fd.write(u'š\nš\nš')
conn = client.last_conn
self.assertTrue(conn.closed)
header, data = conn.interpret()
self.assertEqual(header, {'name': 'mystream', 'type': 'text'})
- self.assertEqual(data, 'text\nstream\nlines')
+ self.assertEqual(data.decode('utf-8'), u'text\nstream\nlines\nš\nš\nš')
def testTextStreamWithParams(self):
client = self._registry.create('test:value')
@@ -166,7 +169,7 @@
'contentType': 'foo/bar',
'tags': {'foo': 'bar', 'baz': 'qux'},
})
- self.assertEqual(data, 'text!')
+ self.assertEqual(data.decode('utf-8'), u'text!')
def testBinaryStream(self):
client = self._registry.create('test:value',
@@ -180,14 +183,14 @@
self.assertEqual(
fd.get_viewer_url(),
'https://example.appspot.com/v/?s=test%2Ffoo%2Fbar%2F%2B%2Fmystream')
- fd.write('\x60\x0d\xd0\x65')
+ fd.write(b'\x60\x0d\xd0\x65')
conn = client.last_conn
self.assertTrue(conn.closed)
header, data = conn.interpret()
self.assertEqual(header, {'name': 'mystream', 'type': 'binary'})
- self.assertEqual(data, '\x60\x0d\xd0\x65')
+ self.assertEqual(data, b'\x60\x0d\xd0\x65')
def testDatagramStream(self):
client = self._registry.create('test:value',
@@ -201,10 +204,10 @@
self.assertEqual(
fd.get_viewer_url(),
'https://example.appspot.com/v/?s=test%2Ffoo%2Fbar%2F%2B%2Fmystream')
- fd.send('datagram0')
- fd.send('dg1')
- fd.send('')
- fd.send('dg3')
+ fd.send(b'datagram0')
+ fd.send(b'dg1')
+ fd.send(b'')
+ fd.send(b'dg3')
conn = client.last_conn
self.assertTrue(conn.closed)
@@ -212,7 +215,7 @@
header, data = conn.interpret()
self.assertEqual(header, {'name': 'mystream', 'type': 'datagram'})
self.assertEqual(list(self._split_datagrams(data)),
- ['datagram0', 'dg1', '', 'dg3'])
+ [b'datagram0', b'dg1', b'', b'dg3'])
def testStreamWithoutPrefixCannotGenerateUrls(self):
client = self._registry.create('test:value',
@@ -257,7 +260,7 @@
header, data = conn.interpret()
self.assertEqual(header, {'name': 'mystream', 'type': 'text'})
- self.assertEqual(data, 'Using a text stream.')
+ self.assertEqual(data.decode('utf-8'), u'Using a text stream.')
if __name__ == '__main__':
diff --git a/tests/streamname_test.py b/tests/streamname_test.py
index d2800f0..45329c0 100755
--- a/tests/streamname_test.py
+++ b/tests/streamname_test.py
@@ -6,7 +6,6 @@
import os
import sys
import unittest
-import StringIO
ROOT_DIR = os.path.dirname(os.path.abspath(os.path.join(
__file__.decode(sys.getfilesystemencoding()),
diff --git a/tests/varint_test.py b/tests/varint_test.py
index 527d7b0..0f1a51c 100755
--- a/tests/varint_test.py
+++ b/tests/varint_test.py
@@ -3,11 +3,13 @@
# Use of this source code is governed under the Apache License, Version 2.0
# that can be found in the LICENSE file.
+from io import BytesIO
+
+import binascii
import itertools
import os
import sys
import unittest
-import StringIO
ROOT_DIR = os.path.dirname(os.path.abspath(os.path.join(
__file__.decode(sys.getfilesystemencoding()),
@@ -28,9 +30,9 @@
(0x81, b'\x81\x01'),
(0x18080, b'\x80\x81\x06'),
):
- sio = StringIO.StringIO()
- count = varint.write_uvarint(sio, base)
- act = sio.getvalue()
+ bytesIO = BytesIO()
+ count = varint.write_uvarint(bytesIO, base)
+ act = bytesIO.getvalue()
self.assertEqual(act, exp,
"Encoding for %d (%r) doesn't match expected (%r)" % (base, act, exp))
@@ -41,22 +43,21 @@
def testVarintEncodeDecode(self):
seed = (b'\x00', b'\x01', b'\x55', b'\x7F', b'\x80', b'\x81', b'\xff')
for perm in itertools.permutations(seed):
- perm = ''.join(perm).encode('hex')
+ perm = b''.join(perm)
while len(perm) > 0:
- exp = int(perm.encode('hex'), 16)
-
- sio = StringIO.StringIO()
- count = varint.write_uvarint(sio, exp)
- sio.seek(0)
- act, count = varint.read_uvarint(sio)
+ exp = int(binascii.hexlify(perm), 16)
+ bytesIO = BytesIO()
+ count = varint.write_uvarint(bytesIO, exp)
+ bytesIO.seek(0)
+ act, count = varint.read_uvarint(bytesIO)
self.assertEqual(act, exp,
"Decoded %r (%d) doesn't match expected (%d)" % (
- sio.getvalue().encode('hex'), act, exp))
- self.assertEqual(count, len(sio.getvalue()),
+ binascii.hexlify(bytesIO.getvalue()), act, exp))
+ self.assertEqual(count, len(bytesIO.getvalue()),
"Decoded length (%d) doesn't match expected (%d)" % (
- count, len(sio.getvalue())))
+ count, len(bytesIO.getvalue())))
if perm == 0:
break
diff --git a/varint.py b/varint.py
index 7bf3cca..c40bb5f 100644
--- a/varint.py
+++ b/varint.py
@@ -3,6 +3,7 @@
# that can be found in the LICENSE file.
import os
+import struct
import sys
@@ -28,7 +29,7 @@
if val > 0:
byte |= 0b10000000
- w.write(chr(byte))
+ w.write(struct.pack('B', byte))
count += 1
return count
@@ -55,7 +56,7 @@
if len(byte) == 0:
raise ValueError('UVarint was not terminated')
- byte = ord(byte)
+ byte = struct.unpack('B', byte)[0]
result |= ((byte & 0b01111111) << (7 * count))
count += 1
if byte & 0b10000000 == 0: