TaskQueue should cancel timeout runnables when running tasks. Also simplify the
logic around race conditions in TimeoutTaskWrapper.

PiperOrigin-RevId: 243159468
Change-Id: Ic2955810cc4059c1f24556514c49b4d4e9029a74
diff --git a/src/main/java/com/google/android/libraries/feed/common/concurrent/TaskQueue.java b/src/main/java/com/google/android/libraries/feed/common/concurrent/TaskQueue.java
index a4788fc..a6447d6 100644
--- a/src/main/java/com/google/android/libraries/feed/common/concurrent/TaskQueue.java
+++ b/src/main/java/com/google/android/libraries/feed/common/concurrent/TaskQueue.java
@@ -580,9 +580,9 @@
    * not started before the timeout millis.
    */
   private final class TimeoutTaskWrapper extends TaskWrapper {
-    private final Runnable timeoutRunnable;
     private final AtomicBoolean started = new AtomicBoolean(false);
-    private final AtomicBoolean ranTimeoutTask = new AtomicBoolean(false);
+    private final Runnable timeoutRunnable;
+    /*@Nullable*/ private CancelableTask timeoutTask = null;
 
     TimeoutTaskWrapper(
         @Task int task, @TaskType int taskType, Runnable taskRunnable, Runnable timeoutRunnable) {
@@ -595,26 +595,38 @@
      * #timeoutRunnable} will run.
      */
     private TimeoutTaskWrapper startTimeout(long timeoutMillis) {
-      mainThreadRunner.executeWithDelay("taskTimeout", this::runTimeoutCallback, timeoutMillis);
+      timeoutTask =
+          mainThreadRunner.executeWithDelay("taskTimeout", this::runTimeoutCallback, timeoutMillis);
       return this;
     }
 
     @Override
     void runTask() {
-      started.set(true);
-      if (ranTimeoutTask.get()) {
-        Logger.w(TAG, " - We ran the TimeoutTask already");
+      // If the boolean is already set then runTimeoutCallback has run.
+      if (started.getAndSet(true)) {
+        Logger.w(TAG, " - runTimeoutCallback already ran [%s]", task);
+        executeNextTask();
         return;
       }
+
+      CancelableTask localTask = timeoutTask;
+      if (localTask != null) {
+        Logger.i(TAG, "Cancelling timeout [%s]", task);
+        localTask.cancel();
+        timeoutTask = null;
+      }
+
       super.runTask();
     }
 
     private void runTimeoutCallback() {
-      if (started.get()) {
+      // If the boolean is already set then runTask has run.
+      if (started.getAndSet(true)) {
+        Logger.w(TAG, " - runTask already ran [%s]", task);
         return;
       }
+
       Logger.w(TAG, "Execute Timeout [%s]: %s", taskType, task);
-      ranTimeoutTask.set(true);
       executor.execute(timeoutRunnable);
     }
   }
diff --git a/src/test/java/com/google/android/libraries/feed/common/concurrent/TaskQueueTest.java b/src/test/java/com/google/android/libraries/feed/common/concurrent/TaskQueueTest.java
index ffe2574..213fc16 100644
--- a/src/test/java/com/google/android/libraries/feed/common/concurrent/TaskQueueTest.java
+++ b/src/test/java/com/google/android/libraries/feed/common/concurrent/TaskQueueTest.java
@@ -49,6 +49,7 @@
           /* checkStarvation= */ true);
 
   private boolean delayedTaskHasRun = false;
+  private boolean delayedTaskHasTimedOut = false;
 
   @Before
   public void setUp() {
@@ -261,6 +262,47 @@
     assertThat(fakeBasicLoggingApi.lastTaskDelay).isEqualTo(delayTime);
   }
 
+  @Test
+  public void testTimeout_withoutTimeout() {
+    // Put the TaskQueue into a delaying state and schedule a task with a timeout.
+    taskQueue.initialize(this::noOp);
+    taskQueue.execute(Task.UNKNOWN, TaskType.HEAD_INVALIDATE, this::noOp);
+    taskQueue.execute(
+        Task.UNKNOWN,
+        TaskType.BACKGROUND,
+        this::delayedTask,
+        this::delayedTaskTimeout,
+        /*timeoutMillis= */ 10L);
+
+    fakeClock.advance(9L);
+    taskQueue.execute(Task.UNKNOWN, TaskType.HEAD_RESET, this::noOp);
+    assertThat(delayedTaskHasRun).isTrue();
+    assertThat(delayedTaskHasTimedOut).isFalse();
+
+    fakeClock.tick();
+    assertThat(delayedTaskHasTimedOut).isFalse();
+  }
+
+  @Test
+  public void testTimeout_withTimeout() {
+    // Put the TaskQueue into a delaying state and schedule a task with a timeout.
+    taskQueue.initialize(this::noOp);
+    taskQueue.execute(Task.UNKNOWN, TaskType.HEAD_INVALIDATE, this::noOp);
+    taskQueue.execute(
+        Task.UNKNOWN,
+        TaskType.BACKGROUND,
+        this::delayedTask,
+        this::delayedTaskTimeout,
+        /*timeoutMillis= */ 10L);
+
+    fakeClock.advance(10L);
+    assertThat(delayedTaskHasRun).isFalse();
+    assertThat(delayedTaskHasTimedOut).isTrue();
+
+    taskQueue.execute(Task.UNKNOWN, TaskType.HEAD_RESET, this::noOp);
+    assertThat(delayedTaskHasRun).isFalse();
+  }
+
   private void runAndAssertStarvationChecks() {
     int starvationTaskCount = 0;
     long startTimeMillis = fakeClock.currentTimeMillis();
@@ -290,6 +332,10 @@
     delayedTaskHasRun = true;
   }
 
+  private void delayedTaskTimeout() {
+    delayedTaskHasTimedOut = true;
+  }
+
   private void longRunningTask() {
     fakeClock.advance(LONG_RUNNING_TASK_TIME);
   }