blob: 4d4a227ac789c3e1350bd823b9e867b518a2246f [file] [log] [blame]
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Unittests for feature_extractor_utils.py."""
# pylint: disable=g-bad-import-order
import os
import unittest
import mock
from lib import constants
from lib import feature_extractor_utils
class TestUtilFunctions(unittest.TestCase):
  """Checks the helper functions exposed by feature_extractor_utils."""

  def _ExpectedDirs(self, *parts):
    """Builds the expected (raw_data, feature) directory pair.

    Args:
      parts: Leading path components shared by both directories.

    Returns:
      A 2-tuple: parts joined with 'raw_data', then parts joined with
      'feature'.
    """
    prefix = list(parts)
    raw_dir = os.path.join(*(prefix + ['raw_data']))
    feature_dir = os.path.join(*(prefix + ['feature']))
    return raw_dir, feature_dir

  def testGetInOutDirForCLTrainLocal(self):
    """GetInOutDir() for a local training run on CL features."""
    actual = feature_extractor_utils.GetInOutDir(True,
                                                 constants.FEATURE_TYPE_CL)
    wanted = self._ExpectedDirs(constants.ROOT_DIR,
                                constants.ROOT_DATA_FOLDER,
                                constants.USAGE_TRAIN,
                                constants.FEATURE_TYPE_CL)
    self.assertEqual(actual, wanted)

  def testGetInOutDirForStagePredictLocal(self):
    """GetInOutDir() for a local prediction run on stage features."""
    actual = feature_extractor_utils.GetInOutDir(False,
                                                 constants.FEATURE_TYPE_STAGE)
    wanted = self._ExpectedDirs(constants.ROOT_DIR,
                                constants.ROOT_DATA_FOLDER,
                                constants.USAGE_PREDICT,
                                constants.FEATURE_TYPE_STAGE)
    self.assertEqual(actual, wanted)

  def testGetInOutDirForStagePredictOnCloud(self):
    """GetInOutDir() for a non-local (cloud) stage prediction run.

    With the third argument set to False, the expected paths omit
    constants.ROOT_DIR.
    """
    actual = feature_extractor_utils.GetInOutDir(False,
                                                 constants.FEATURE_TYPE_STAGE,
                                                 False)
    wanted = self._ExpectedDirs(constants.ROOT_DATA_FOLDER,
                                constants.USAGE_PREDICT,
                                constants.FEATURE_TYPE_STAGE)
    self.assertEqual(actual, wanted)

  @mock.patch('lib.feature_extractor_utils.os.listdir')
  def testGetFilenamesLocal(self, mock_listdir):
    """GetFilenames() for a local run over a mocked directory listing.

    The listing is deliberately unordered; the result is expected to be
    the set of those entries, each prefixed with './'.
    """
    listed = ['2_2', '3_3', '4_4', '1_1', '0_1']
    mock_listdir.return_value = listed
    found = feature_extractor_utils.GetFilenames('.', None)
    wanted = {'./0_1', './1_1', './2_2', './3_3', './4_4'}
    self.assertEqual(found, wanted)
if __name__ == '__main__':
  # Run this module's test cases when it is executed directly as a script.
  unittest.main()