blob: a511e4bbd0ccae9fc6f535c6892b853f8234bd81 [file] [log] [blame]
# Copyright 2018 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import json
import os
import shutil
import socket
import tempfile
import unittest
import servodutil
class TestServoScratch(unittest.TestCase):
  """Unit tests for servodutil.ServoScratch and module-level helpers.

  Each test gets its own temporary scratch directory, so the tests do not
  interfere with each other or with a real servod scratch location.
  """

  def setUp(self):
    """Setup each test.

    Prepare a ServoScratch on a temp directory & prepare a convenience entry.
    """
    self._scratchdir = tempfile.mkdtemp()
    self._scratch = servodutil.ServoScratch(self._scratchdir)
    self._dport = 31234
    self._dserials = ['9', '17']
    self._dpid = 1827
    # Commonly used test entry
    self._entry = {'pid': self._dpid,
                   'serials': self._dserials,
                   'port': self._dport}

  def tearDown(self):
    """Remove entry directory structure created during the test."""
    shutil.rmtree(self._scratchdir)

  def test_Init(self):
    """Verify servodutil creates the directory if it doesn't exist."""
    new_scratchdir = os.path.join(self._scratchdir, 'testdir')
    self.assertFalse(os.path.exists(new_scratchdir))
    # Only initialize the servodutil to ensure it creates |new_scratchdir|
    servodutil.ServoScratch(new_scratchdir)
    self.assertTrue(os.path.exists(new_scratchdir))
    os.rmdir(new_scratchdir)

  def test_AddEntry(self):
    """AddEntry creates & saves entry, and makes symlinks for each serial."""
    self._scratch.AddEntry(port=self._dport, serials=self._dserials,
                           pid=self._dpid)
    # Ensure port entry created
    port_entry = os.path.join(self._scratchdir, str(self._dport))
    self.assertTrue(os.path.exists(port_entry))
    # Resolve the port entry too: the temp dir itself may live behind a
    # symlink (e.g. /var -> /private/var), and realpath() on the serial link
    # resolves that as well.
    real_port_entry = os.path.realpath(port_entry)
    for serial in self._dserials:
      serial_entry = os.path.join(self._scratchdir, serial)
      # Ensure serial number entry created
      self.assertTrue(os.path.exists(serial_entry))
      # Ensure serial number entry is a link
      self.assertTrue(os.path.islink(serial_entry))
      # Ensure serial number entry is a link to the port entry
      self.assertEqual(os.path.realpath(serial_entry), real_port_entry)
    # Load entry
    with open(port_entry, 'r') as entryf:
      entry = json.load(entryf)
    # Compare entry loaded with entry saved
    self.assertEqual(entry, self._entry)

  def test_AddEntryNonNumericalPort(self):
    """Verify AddEntry raises ServodUtilError when port can't be cast to int."""
    port = 'hello'
    with self.assertRaises(servodutil.ServodUtilError):
      self._scratch.AddEntry(port, self._dserials, self._dpid)

  def test_AddEntryNonNumericalPID(self):
    """Verify AddEntry raises ServodUtilError when pid can't be cast to int."""
    pid = 'hello'
    with self.assertRaises(servodutil.ServodUtilError):
      self._scratch.AddEntry(self._dport, self._dserials, pid)

  def test_AddEntryNonListlikeSerials(self):
    """Verify AddEntry raises ServodUtilError when serials is not iterable."""
    serials = 17
    with self.assertRaises(servodutil.ServodUtilError):
      self._scratch.AddEntry(self._dport, serials, self._dpid)

  def test_AddEntryTwice(self):
    """Verify AddEntry raises ServodUtilError when adding same entry twice."""
    self._scratch.AddEntry(port=self._dport, serials=self._dserials,
                           pid=self._dpid)
    # Ensure error when adding the same entry twice
    with self.assertRaises(servodutil.ServodUtilError):
      self._scratch.AddEntry(port=self._dport, serials=self._dserials,
                             pid=self._dpid)
    # TODO(coconutruben): flesh out more to test equal port, equal serial,
    # and potentially equal pid individually.

  def _manually_add_entry(self, entry=None):
    """Manually add an entry, bypassing ServoScratch.AddEntry.

    Writes |entry| as JSON to <scratchdir>/<port> and creates one symlink per
    serial (<scratchdir>/<serial>) pointing at that file.

    Args:
      entry: dict with 'port', 'serials' and 'pid' keys. Uses self._entry if
             no entry is supplied.

    Returns:
      list of paths created (the port file first, then one per serial link).
    """
    if entry is None:
      entry = self._entry
    files_added = []
    entryfn = os.path.join(self._scratchdir, str(entry['port']))
    # Manually add an entry and the symlinks
    with open(entryfn, 'w') as entryf:
      json.dump(entry, entryf)
    files_added.append(entryfn)
    for serial in entry['serials']:
      linkfn = os.path.join(self._scratchdir, str(serial))
      os.symlink(entryfn, linkfn)
      files_added.append(linkfn)
    return files_added

  def _scratch_files(self):
    """Return the set of full paths currently present in the scratch dir."""
    return {os.path.join(self._scratchdir, f)
            for f in os.listdir(self._scratchdir)}

  def test_RemoveEntry(self):
    """Verify RemoveEntry removes an entry fully (file + symlinks)."""
    port = '9809'
    serials = ['8000', '237300', 'lolaserial']
    entry2 = {'pid': self._dpid,
              'serials': serials,
              'port': port}
    entry_files = set(self._manually_add_entry())
    entry2_files = set(self._manually_add_entry(entry2))
    # Ensure there's a file for each serial, and one for the port
    self.assertEqual(self._scratch_files(), entry_files | entry2_files)
    self._scratch.RemoveEntry(port)
    # Ensure only the files of the removed entry are gone
    self.assertEqual(self._scratch_files(), entry_files)
    # Ensure the right files were removed
    self.assertFalse(os.path.exists(os.path.join(self._scratchdir, port)))
    for serial in serials:
      self.assertFalse(os.path.exists(os.path.join(self._scratchdir, serial)))

  def test_RemoveEntryBadIdentifier(self):
    """Verify RemoveEntry quietly ignores removing an unknown identifier."""
    self._manually_add_entry()
    self._scratch.RemoveEntry('badid')

  def test_FindByIdPort(self):
    """Verify FindById works using ports."""
    self._manually_add_entry()
    entry_from_file = self._scratch.FindById(self._dport)
    self.assertEqual(entry_from_file, self._entry)

  def test_FindByIdSerial(self):
    """Verify FindById works using serials."""
    self._manually_add_entry()
    for serial in self._dserials:
      entry_from_file = self._scratch.FindById(serial)
      self.assertEqual(entry_from_file, self._entry)

  def test_FindByIdBadJSON(self):
    """Verify FindById raises ServodUtilError when id points to invalid JSON."""
    identifier = 'nonsense'
    entryfn = os.path.join(self._scratchdir, identifier)
    with open(entryfn, 'w') as entryf:
      entryf.write('This is not JSON')
    self.assertTrue(os.path.exists(entryfn))
    with self.assertRaises(servodutil.ServodUtilError):
      self._scratch.FindById(identifier)
    # FindById removes invalid json files
    self.assertFalse(os.path.exists(entryfn))

  def test_FindByIdBadId(self):
    """Verify FindById raises ServodUtilError when using an unknown id."""
    self._manually_add_entry()
    with self.assertRaises(servodutil.ServodUtilError):
      self._scratch.FindById('badid')

  def test_GetAllEntriesEmpty(self):
    """Verify GetAllEntries() doesn't break when there are no entries."""
    self.assertEqual(self._scratch._GetAllEntries(), [])

  def test_GetAllEntries(self):
    """Verify GetAllEntries() retrieves all entries."""
    # Entries added, keyed by port so each loaded entry can be matched.
    mentries = {
        9999: {'port': 9999, 'serials': ['1999'], 'pid': 1234},
        9998: {'port': 9998, 'serials': ['1998'], 'pid': 1235},
        9997: {'port': 9997, 'serials': ['1997'], 'pid': 1236},
    }
    for mentry in mentries.values():
      self._manually_add_entry(mentry)
    entries = self._scratch._GetAllEntries()
    self.assertEqual(len(entries), len(mentries))
    for entry in entries:
      self.assertEqual(mentries[entry['port']], entry)

  def test_SanitizeNothingToDo(self):
    """Verify Sanitize does not remove active scratch entry."""
    self._manually_add_entry()
    testsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Close the socket even if an assertion below fails or _Sanitize raises.
    self.addCleanup(testsock.close)
    testsock.bind(('localhost', self._dport))
    # os.listdir order is arbitrary, so compare sorted listings.
    prevfiles = sorted(os.listdir(self._scratchdir))
    self._scratch._Sanitize()
    postfiles = sorted(os.listdir(self._scratchdir))
    # Port is attached so Sanitize should not wipe the entry
    self.assertEqual(prevfiles, postfiles)

  def test_SanitizeStaleEntry(self):
    """Verify that a stale entry in servoscratch is removed."""
    self._manually_add_entry()
    self._scratch._Sanitize()
    # The port is likely not connected to anything so Sanitize should consider
    # this a stale entry and remove it.
    self.assertFalse(os.listdir(self._scratchdir))

  def test_SanitizeMultipleStaleEntry(self):
    """Verify that multiple stale entries in servoscratch are all removed."""
    self._manually_add_entry()
    entry2 = {'pid': 12345,
              'serials': ['this-is-not-a-serial'],
              'port': 9888}
    self._manually_add_entry(entry2)
    self._scratch._Sanitize()
    # The ports are likely not connected to anything so Sanitize should
    # consider these stale entries and remove them.
    self.assertFalse(os.listdir(self._scratchdir))

  def test_ConvertNameToMethodNoCamel(self):
    """Verify that strings without '-' only get capitalized."""
    name = 'hi'
    self.assertEqual(servodutil._ConvertNameToMethod(name), 'Hi')

  def test_ConvertNameToMethodDoubleCamel(self):
    """Verify proper camel-case conversion for strings containing '-'."""
    name = 'hi-there-friends'
    self.assertEqual(servodutil._ConvertNameToMethod(name), 'HiThereFriends')
if __name__ == '__main__':
  # Executed directly (not imported): hand control to unittest's CLI runner,
  # which discovers every TestCase defined in this module.
  unittest.main(module='__main__')