Optofidelity: PWM Compensation

This change adds PWM compensation by estimating the characteristics of
the PWM artifacts (pulse response and frequency) at the time of screen
calibration. During processing, this compensation is applied to all
normalized frames.

More accurate PWM compensation is possible if the bottom of the screen
always displays white in a test case. This active PWM compensation
reads the PWM artifacts from that area and compensates for them across
the whole screen. Since not all tests show a white area at the bottom,
active compensation has to be enabled separately.
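
Roughly, the new hooks are meant to be used as sketched below (a
minimal, hypothetical usage sketch; `calibration` stands for a
ScreenCalibration and `frame` for a calibrated-frame instance, both
names assumed here):

  # Passive compensation, applied automatically during normalization:
  normalized = calibration.NormalizeFrame(screen_space_frame,
                                          pwm_compensation=True)

  # Active compensation, reading PWM artifacts from the bottom 20 rows
  # (which the test case must keep white) and dividing them out:
  compensated = frame.CompensatePWM(normalized, height=20)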

BUG=chromium:536633
TEST=test cases included

Change-Id: I835524d58221cfc60e492af4014124f3a544311e
Reviewed-on: https://chromium-review.googlesource.com/336525
Commit-Ready: Dennis Kempin <denniskempin@chromium.org>
Tested-by: Dennis Kempin <denniskempin@chromium.org>
Reviewed-by: Dennis Kempin <denniskempin@chromium.org>
diff --git a/optofidelity/optofidelity/detection/_calibrated_frame.py b/optofidelity/optofidelity/detection/_calibrated_frame.py
index 37c4955..dfbe2f9 100644
--- a/optofidelity/optofidelity/detection/_calibrated_frame.py
+++ b/optofidelity/optofidelity/detection/_calibrated_frame.py
@@ -73,7 +73,7 @@
     if self.camera_space_prev_frame is None:
       return None
     normalized = self._screen_calibration.NormalizeFrame(
-        self.screen_space_prev_frame)
+        self.screen_space_prev_frame, pwm_compensation=True)
     return Filter.Truncate(normalized)
 
   @const_property
@@ -81,7 +81,7 @@
     """:returns np.ndarray: color normalized frame in screen space."""
     self._require_calibration()
     normalized = self._screen_calibration.NormalizeFrame(
-        self.screen_space_frame)
+        self.screen_space_frame, pwm_compensation=True)
     thresh = self.BRIGHTNESS_OUT_OF_RANGE_THRESH
     if (np.any(normalized > 1 + thresh) or np.any(normalized < -thresh)):
       raise ValueError("Brightness is outside of normalized range.")
@@ -129,3 +129,24 @@
   def _require_calibration(self):
     if not self._screen_calibration:
       raise ValueError("screen_calibration is required for this operation.")
+
+  def MeasurePWMProfile(self, image=None, height=20):
+    """Measures the PWM profile from the bottom `height` rows of the image.
+
+    The bottom of the screen is assumed to display white, so any deviation
+    in that area is caused by PWM artifacts.
+    """
+    if image is None:
+      image = self.screen_space_normalized
+    top = image.shape[0] - height
+    measurement_area = Shape.FromRectangle(image.shape, top=top)
+    return measurement_area.CalculateProfile(image)
+
+  def MeasurePWM(self, image=None, height=20):
+    """Returns the measured PWM profile tiled to the full image height."""
+    if image is None:
+      image = self.screen_space_normalized
+    pwm_profile = self.MeasurePWMProfile(image, height)
+    return np.tile(pwm_profile, (image.shape[0], 1))
+
+  def CompensatePWM(self, image=None, height=20):
+    """Divides the measured PWM artifacts out of the image."""
+    if image is None:
+      image = self.screen_space_normalized
+    pwm = self.MeasurePWM(image, height)
+    pwm[pwm < 0.1] = 0.1  # prevent division by zero
+    return image / pwm
+
+  def CompensatePWMProfile(self, profile, image=None, height=20):
+    """Divides the measured PWM artifacts out of a 1D profile."""
+    pwm = self.MeasurePWMProfile(image, height)
+    pwm[pwm < 0.1] = 0.1  # prevent division by zero
+    return profile / pwm
diff --git a/optofidelity/optofidelity/detection/screen_calibration.py b/optofidelity/optofidelity/detection/screen_calibration.py
index 748cc3f..bfc5a4c 100644
--- a/optofidelity/optofidelity/detection/screen_calibration.py
+++ b/optofidelity/optofidelity/detection/screen_calibration.py
@@ -6,12 +6,14 @@
 import logging
 import sys
 
+from matplotlib import pyplot
 from safetynet import Tuple, TypecheckMeta
 import cv2
 import numpy as np
 import skimage.morphology as morphology
 import skimage.transform as transform
 
+from optofidelity.util import nputil
 from optofidelity.videoproc import DebugView, Filter, Shape
 
 _log = logging.getLogger(__name__)
@@ -59,7 +61,7 @@
   """Compensate for PWM if the range of colors in the white reference image is
      larger than this value."""
 
-  def __new__(cls, black_frame, white_frame):
+  def __new__(cls, black_frame, white_frame, pwm_pulse=None):
     """Creates new screen calibration from black and white video frames.
 
     If the white_frame is all white, the calibration will not do any
@@ -71,6 +73,15 @@
     self = super(ScreenCalibration, cls).__new__(cls)
     self.is_identity = np.allclose(white_frame, 1.0)
 
+    self.pwm_profile = None
+    self.pwm_pulse = None
+    self.pwm_pulse_length = None
+
+    if pwm_pulse is not None:
+      self.pwm_pulse = pwm_pulse
+      self.pwm_profile = np.tile(pwm_pulse, 10)
+      self.pwm_pulse_length = len(pwm_pulse)
+
     if self.is_identity:
       self.shape = Shape(np.ones(white_frame.shape, dtype=np.bool))
       self._rectification_transform = None
@@ -98,7 +109,6 @@
       features_image = self._GetFeaturesImage(self.black_frame)
       self._good_features = cv2.goodFeaturesToTrack(features_image, 200, 0.01,
                                                     30)
-      self._CompensatePWM()
     return self
 
   @property
@@ -123,8 +133,21 @@
     :param bool debug
     :returns ScreenCalibration
     """
-    ref_images = []
     with video_reader.PrefetchEnabled():
+      # Skim through video to find rough location of the screen.
+      prev_frame = None
+      delta_accum = np.zeros(video_reader.frame_shape)
+      for i, frame in video_reader.Frames(step=10):
+        if prev_frame is None:
+          prev_frame = frame
+          continue
+        delta = np.abs(frame - prev_frame)
+        delta_accum += delta
+      shapes = Shape.Shapes(delta_accum > 0.9)
+      screen_shape = max(shapes, key=lambda s: s.area)
+
+      # Walk frame-by-frame to find reference image for black and white screen.
+      ref_indices = []
       frames = video_reader.Frames()
 
       i, prev_frame = frames.next()
@@ -134,7 +157,7 @@
       for i, frame in frames:
         # Calculate the mid-range value of the inter-frame difference.
         # This reliably describes the general direction of change.
-        diff = frame - prev_frame
+        diff = (frame - prev_frame) * screen_shape.mask
         midrange = Filter.StableMidRange(diff)
         _log.debug("%4d: midrange=%.2f, dir=%d, dur=%d", i, midrange,
                   change_direction, change_duration)
@@ -151,32 +174,143 @@
            # Wait until the screen is flipping its change direction; this is
            # a reference frame we want to pick.
             if change_duration > 3:
-              ref_images.append(prev_frame)
-              if len(ref_images) >= 2:
+              ref_indices.append(i - 1)
+              if len(ref_indices) >= 2:
                 break
             change_direction = 0
             change_duration = 0
 
         prev_frame = frame
 
-    if len(ref_images) < 2:
+    if len(ref_indices) < 2:
       raise Exception("Cannot find flashing screen")
 
-    if np.mean(ref_images[0]) > np.mean(ref_images[1]):
-      return cls(ref_images[1], ref_images[0])
-    else:
-      return cls(ref_images[0], ref_images[1])
+    # Identify black and white screen
+    black_ref = video_reader.FrameAt(ref_indices[0])
+    white_ref = video_reader.FrameAt(ref_indices[1])
+    white_index = ref_indices[1]
+    if np.mean(black_ref) > np.mean(white_ref):
+      (black_ref, white_ref) = (white_ref, black_ref)
+      white_index = ref_indices[0]
 
-  def _CompensatePWM(self):
-    profile = np.mean(self.white_reference, 0)
-    profile_deriv = np.diff(profile)
+    # Collect frames for PWM calibration
+    (height, width) = video_reader.frame_shape
+    pwm_frames = np.zeros((height, width, 6), dtype=np.float)
+    for i in range(6):
+      pwm_frames[:,:,i] = video_reader.FrameAt(white_index + i)
+    white_ref = np.max(pwm_frames, axis=2)
 
-    deriv_range = np.max(profile_deriv) - np.min(profile_deriv)
-    if deriv_range > self.PWM_COMPENSATION_LEVEL:
-      _log.warn("Compensating for PWM (range=%.4f)" % deriv_range)
-      self.white_reference[:] = np.min(profile)
-    else:
-      _log.warn("No PWM compensation (range=%.4f)" % deriv_range)
+    calibration = cls(black_ref, white_ref)
+    calibration.CalibratePWM(pwm_frames)
+    return calibration
+
+  def CalibratePWM(self, pwm_frames, debug=False):
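+    """Estimates the PWM pulse response and pulse length from pwm_frames."""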
+    def GetPWMProfile(pwm_frames, index):
+      """Returns deviation from white reference as a profile."""
+      screen_space_pwm = self.CameraToScreenSpace(pwm_frames[:, :, index])
+      delta = self.white_reference - screen_space_pwm
+      profile = np.mean(delta, 0)
+      return profile
+
+    def ListPWMPulseResponses(pwm_frames):
+      """Yields snippets of PWM pulses detected in the pwm_frames."""
+      for i in range(0, pwm_frames.shape[2]):
+        profile = GetPWMProfile(pwm_frames, i)
+
+        # No PWM artifacts? Great!
+        if np.max(profile) - np.min(profile) < 0.2:
+          return
+
+        # When the profile is normalized, the PWM shows up as pulses from 0
+        # to 1. The transitions between those states tell us where each
+        # pulse starts and ends.
+        normalized = profile / np.max(profile)
+        for start, end in nputil.FindPeaks(normalized, min_value=0.9,
+                                           max_slope=0.1, max_mid_range=0.1):
+          if start < 10:
+            continue
+          start = start - 10
+          end = np.min((end + 10, len(profile)))
+          pulse = profile[start:end]
+          yield pulse
+
+    def EstimateAveragePulseResponse(pulse_list):
+      """Aligns all pulse responses and returns the averaged pulse response."""
+      max_pulse_length = np.max([len(p) for p in pulse_list])
+      pulse0 = np.copy(pulse_list[0])
+      pulse0.resize(max_pulse_length)
+
+      # Align all pulses to first pulse
+      aligned_pulses = []
+      for pulse in pulse_list:
+        shift = nputil.EstimateShift(pulse0, pulse, "mae", debug=debug)
+        matched = nputil.AlignArrays(pulse0, pulse, shift)
+        aligned_pulses.append(matched)
+
+      # Slightly overestimate: over-correction is less of a problem than
+      # under-correction.
+      estimate = np.max(np.asarray(aligned_pulses), axis=0)
+      estimate = estimate * 1.1
+
+      # Force the beginning and end of the response to 0
+      base_level = np.max((estimate[:2], estimate[-2:]))
+      estimate = estimate - base_level
+      estimate[estimate < 0] = 0
+
+      if debug:
+        pyplot.figure()
+        for pulse in aligned_pulses:
+          pyplot.plot(pulse)
+        pyplot.plot(estimate, "o")
+        pyplot.show()
+      return estimate
+
+    def EstimatePulseLength(pulse, pwm_frames):
+      """Estimates the frequency of PWM pulses through cross-correlation."""
+      lengths = []
+      for i in range(0, pwm_frames.shape[2]):
+        profile = GetPWMProfile(pwm_frames, i)
+
+        # The correlation shows peaks at the interval at which the pulse
+        # repeats in the profile.
+        corr = np.correlate(pulse, profile, "full")
+        corr = corr / np.max(corr)
+
+        # Find strong maximums
+        maximums = nputil.FindLocalExtremes(corr)
+        filtered_maximums = filter(lambda e: corr[e] > 0.5, maximums)
+        sorted_maximums = sorted(filtered_maximums)
+
+        # Distance between maximums is the pulse length
+        lengths.extend(np.diff(sorted_maximums))
+
+        if debug:
+          pyplot.figure()
+          pyplot.subplot(1, 2, 1)
+          pyplot.plot(profile)
+          pyplot.plot(pulse)
+          pyplot.subplot(1, 2, 2)
+          pyplot.plot(corr)
+          pyplot.plot(sorted_maximums, [corr[e] for e in sorted_maximums], "o")
+
+      if debug:
+        pyplot.show()
+
+      return int(np.mean(lengths))
+
+    pulse_list = list(ListPWMPulseResponses(pwm_frames))
+    if len(pulse_list) < 4:
+      _log.info("No PWM artifact compensation needed.")
+      return
+
+    pulse = EstimateAveragePulseResponse(pulse_list)
+    length = EstimatePulseLength(pulse, pwm_frames)
+    _log.info("Detected PWM with pulse length of %d", length)
+
+    pulse = np.copy(pulse)
+    pulse.resize(length)
+    self.pwm_pulse = pulse
+    self.pwm_profile = np.tile(pulse, 10)
+    self.pwm_pulse_length = length
 
   def _GetFeaturesImage(self, camera_space_frame):
     """Trims image to area not showing the screen nor robot arm.
@@ -222,14 +356,39 @@
       DebugView(debug=delta, title="Stabilization Debug")
     return stable
 
-  def NormalizeFrame(self, screen_space_frame):
+  def NormalizeFrame(self, screen_space_frame, pwm_compensation=False,
+                     debug=False):
     """Normalizes color on a frame in screen space.
 
     :param np.ndarray screen_space_frame
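+    :param bool pwm_compensation: divide out PWM artifacts if calibrated.
+    :param bool debug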
     :returns np.ndarray
     """
-    return Filter.Truncate((screen_space_frame - self.black_reference) /
-                           (self.white_reference - self.black_reference))
+    normalized = Filter.Truncate((screen_space_frame - self.black_reference) /
+                                 (self.white_reference - self.black_reference))
+
+    if not pwm_compensation or self.pwm_profile is None:
+      return normalized
+
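+    # Estimate the per-column PWM profile of this frame from its deviation
+    # from white, ignoring dark pixels that are actual screen content.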
+    delta = 1.0 - normalized
+
+    mask = delta < 0.8
+    profile = np.sum(delta * mask, 0) / np.sum(mask, 0)
+
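+    # Align the calibrated PWM profile with the artifacts observed in this
+    # frame before dividing it out.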
+    shift = nputil.EstimateShift(profile, self.pwm_profile, "mae",
+                                 max_shift=self.pwm_pulse_length, debug=debug)
+    pwm_profile = 1.0 - nputil.AlignArrays(profile, self.pwm_profile, shift,
+                                           mode="wrap")
+
+    # Stretch the 1D profile across the full frame height.
+    stretched = np.tile(pwm_profile, (screen_space_frame.shape[0], 1))
+    stretched[stretched < 0.1] = 0.1  # prevent division by zero
+
+    compensated = Filter.Truncate(normalized / stretched)
+
+    if debug:
+      DebugView(orig=screen_space_frame, pwm=stretched, compensated=compensated)
+
+    return compensated
 
   def Validate(self):
     """Runs some checks on the calibration.
@@ -334,7 +493,7 @@
     return max(shapes, key=lambda s: s.area)
 
   def __getnewargs__(self):
-    return (self.black_frame, self.white_frame)
+    return (self.black_frame, self.white_frame, self.pwm_pulse)
 
   def __getstate__(self):
     return {}
diff --git a/optofidelity/optofidelity/videoproc/shape.py b/optofidelity/optofidelity/videoproc/shape.py
index d247486..e379caa 100644
--- a/optofidelity/optofidelity/videoproc/shape.py
+++ b/optofidelity/optofidelity/videoproc/shape.py
@@ -191,6 +191,9 @@
       self._contour = self._contour.astype(np.bool)
     return self._contour
 
+  def CalculateProfile(self, image):
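+    """:returns np.ndarray: column-wise mean of image over the shape's mask."""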
+    return np.sum(image * self.mask, 0) / np.sum(self.mask, 0)
+
   @classmethod
   def FromRectangle(cls, array_shape, left=None, right=None, top=None,
                     bottom=None):
diff --git a/optofidelity/tests/detection/test_data/calibration_1.avi b/optofidelity/tests/detection/test_data/calibration_1.avi
new file mode 100644
index 0000000..5df5dcf
--- /dev/null
+++ b/optofidelity/tests/detection/test_data/calibration_1.avi
Binary files differ
diff --git a/optofidelity/tests/detection/test_data/pwm_calibration_2.avi b/optofidelity/tests/detection/test_data/pwm_calibration_2.avi
new file mode 100644
index 0000000..d3ce297
--- /dev/null
+++ b/optofidelity/tests/detection/test_data/pwm_calibration_2.avi
Binary files differ
diff --git a/optofidelity/tests/detection/test_screen_calibration.py b/optofidelity/tests/detection/test_screen_calibration.py
index f01a8c7..97e3829 100644
--- a/optofidelity/tests/detection/test_screen_calibration.py
+++ b/optofidelity/tests/detection/test_screen_calibration.py
@@ -6,10 +6,10 @@
 import cPickle as pickle
 
 import numpy as np
-
+import cv2
 from optofidelity.detection.screen_calibration import ScreenCalibration
 from optofidelity.videoproc import DebugView
-
+from tests.config import CONFIG
 from . import test_data
 
 
@@ -36,6 +36,11 @@
       calibration = ScreenCalibration.FromScreenFlashVideo(video_reader)
       self.assertCalibrationIsConsistent(calibration)
 
+  def testLowBrightnessCalibration(self):
+    with test_data.LoadVideo("calibration_1.avi") as video_reader:
+      calibration = ScreenCalibration.FromScreenFlashVideo(video_reader)
+      self.assertCalibrationIsConsistent(calibration)
+
   def testCompactPickling(self):
     before = self.createCalibration()
     pickled = pickle.dumps(before, protocol=2)
@@ -60,15 +65,33 @@
     self.assertTrue(np.allclose(black_normalized, 0))
     self.assertTrue(np.allclose(white_normalized, 1))
 
-  def testPWMCalibration(self):
-    calib = ScreenCalibration.FromScreenFlashVideo(
-        test_data.LoadVideo("pwm_calibration.avi"))
-    self.assertCalibrationIsConsistent(calib)
+  def testPWMCalibration1(self):
+    video = test_data.LoadVideo("pwm_calibration.avi")
+    calibration = ScreenCalibration.FromScreenFlashVideo(video)
 
-  def testPWMCompensation(self):
-    calib = ScreenCalibration(test_data.LoadImage("pwm_calib_black.png"),
-                              test_data.LoadImage("pwm_calib_white.png"))
-    self.assertLess(np.std(calib.white_reference), 0.01)
+    self.verifyPWMCalibration(video, calibration)
+
+  def testPWMCalibration2(self):
+    video = test_data.LoadVideo("pwm_calibration_2.avi")
+    calibration = ScreenCalibration.FromScreenFlashVideo(video)
+    self.verifyPWMCalibration(video, calibration)
+
+  def verifyPWMCalibration(self, video, calibration):
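+    """Asserts PWM was detected; with user_interaction, shows each frame."""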
+    self.assertIsNotNone(calibration.pwm_profile)
+    if not CONFIG.get("user_interaction"):
+      return
+
+    screen_space = calibration.CameraToScreenSpace(video.FrameAt(127))
+    normalized = calibration.NormalizeFrame(screen_space, debug=True,
+                                            pwm_compensation=True)
+
+    for i, frame in video.Frames():
+      screen_space = calibration.CameraToScreenSpace(frame)
+      normalized = calibration.NormalizeFrame(screen_space,
+                                              pwm_compensation=True)
+      print i
+      cv2.imshow("PWM Calibration", normalized)
+      cv2.waitKey()
 
   def testStabilization(self):
     self.skipTest("Manual verification required.")