Add dynamic HWID updater support to shopfloor server.

BUG=chrome-os-partner:12568
TEST=shopfloor_unittest.py

Change-Id: I39314eda1450dec0cf5467052c4060a04008bd9f
Reviewed-on: https://gerrit.chromium.org/gerrit/30660
Tested-by: Jon Salz <jsalz@chromium.org>
Reviewed-by: Hung-Te Lin <hungte@chromium.org>
Commit-Ready: Jon Salz <jsalz@chromium.org>
diff --git a/factory_setup/shopfloor/__init__.py b/factory_setup/shopfloor/__init__.py
index d10873a..aa4a767 100644
--- a/factory_setup/shopfloor/__init__.py
+++ b/factory_setup/shopfloor/__init__.py
@@ -10,6 +10,7 @@
 """
 
 import csv
+import glob
 import logging
 import os
 import time
@@ -24,9 +25,14 @@
 EVENTS_DIR = 'events'
 REPORTS_DIR = 'reports'
 UPDATE_DIR = 'update'
+HWID_UPDATER_PATTERN = 'hwid_*'
 REGISTRATION_CODE_LOG_CSV = 'registration_code_log.csv'
 
 
+class ShopFloorException(Exception):
+  pass
+
+
 class ShopFloorBase(object):
   """Base class for shopfloor servers.
 
@@ -108,6 +114,38 @@
     """
     raise NotImplementedError('GetHWID')
 
+  def _GetHWIDUpdaterPath(self):
+    """Returns the path to HWID updater bundle, if available.
+
+    Returns:
+      The path to the file (or None).
+
+    Raises:
+      ShopFloorException if there are >1 HWID bundles available.
+    """
+    bundles = glob.glob(os.path.join(self.data_dir, HWID_UPDATER_PATTERN))
+    if not bundles:
+      return None
+
+    if len(bundles) > 1:
+      raise ShopFloorException('Multiple HWID bundles available: %s (please '
+                               'delete all but one)' % bundles)
+
+    return bundles[0]
+
+  def GetHWIDUpdater(self):
+    """Returns a HWID updater bundle, if available.
+
+    Returns:
+      The binary-encoded contents of a file named 'hwid_*' in the data
+      directory.  If there are no such files, returns None.
+
+    Raises:
+      ShopFloorException if there are >1 HWID bundles available.
+    """
+    path = self._GetHWIDUpdaterPath()
+    return open(path).read() if path else None
+
   def GetVPD(self, serial):
     """Returns VPD data to set (in dictionary format).
 
diff --git a/factory_setup/shopfloor_server.py b/factory_setup/shopfloor_server.py
index 64c1322..d97d7af 100755
--- a/factory_setup/shopfloor_server.py
+++ b/factory_setup/shopfloor_server.py
@@ -15,12 +15,14 @@
 '''
 
 
+import hashlib
 import imp
 import logging
 import optparse
 import os
 import shopfloor
 import SimpleXMLRPCServer
+from subprocess import Popen, PIPE
 
 
 _DEFAULT_SERVER_PORT = 8082
@@ -128,6 +130,17 @@
     logging.exception('Failed loading module: %s', options.module)
     exit(1)
 
+  # Find the HWID updater (if any).  Throw an exception if there are >1.
+  hwid_updater_path = instance._GetHWIDUpdaterPath()
+  if hwid_updater_path:
+    logging.info('Using HWID updater %s (md5sum %s)' % (
+        hwid_updater_path,
+        hashlib.md5(open(hwid_updater_path).read()).hexdigest()))
+  else:
+    logging.warn('No HWID updater id currently available; add a single '
+                 'file named %s to enable dynamic updating of HWIDs.' %
+                 os.path.join(options.data_dir, shopfloor.HWID_UPDATER_PATTERN))
+
   try:
     instance._StartBase()
     logging.debug('Starting RPC server...')
diff --git a/factory_setup/shopfloor_unittest.py b/factory_setup/shopfloor_unittest.py
index 66af11c..839df08 100755
--- a/factory_setup/shopfloor_unittest.py
+++ b/factory_setup/shopfloor_unittest.py
@@ -76,6 +76,19 @@
     self.assertRaises(xmlrpclib.Fault, self.proxy.GetHWID, 'CR001000')
     self.assertRaises(xmlrpclib.Fault, self.proxy.GetHWID, 'CR001026')
 
+  def testGetHWIDUpdater_None(self):
+    self.assertEquals(None, self.proxy.GetHWIDUpdater())
+
+  def testGetHWIDUpdater_One(self):
+    with open(os.path.join(self.data_dir, 'hwid_updater.sh'), 'w') as f:
+      f.write('foobar')
+    self.assertEquals('foobar', self.proxy.GetHWIDUpdater().data)
+
+  def testGetHWIDUpdater_Two(self):
+    for i in (1, 2):
+      open(os.path.join(self.data_dir, 'hwid_updater_%d.sh' % i), 'w').close()
+    self.assertRaises(xmlrpclib.Fault, self.proxy.GetHWIDUpdater)
+
   def testGetVPD(self):
     # VPD fields defined in simple.csv
     RO_FIELDS = ('keyboard_layout', 'initial_locale', 'initial_timezone')