Define the RequiredContentAdapter interface and incorporate it into
FeedProtocolAdapter.

PiperOrigin-RevId: 250906593
Change-Id: I5f67b3d07b63a2a49580849ea18cf3f93865029d
diff --git a/src/main/java/com/google/android/libraries/feed/api/client/scope/ProcessScopeBuilder.java b/src/main/java/com/google/android/libraries/feed/api/client/scope/ProcessScopeBuilder.java
index 522ce70..6a0f387 100644
--- a/src/main/java/com/google/android/libraries/feed/api/client/scope/ProcessScopeBuilder.java
+++ b/src/main/java/com/google/android/libraries/feed/api/client/scope/ProcessScopeBuilder.java
@@ -64,6 +64,7 @@
 import com.google.android.libraries.feed.hostimpl.storage.InMemoryContentStorage;
 import com.google.android.libraries.feed.hostimpl.storage.InMemoryJournalStorage;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.concurrent.Executor;
 
 /** Creates an instance of {@link ProcessScope} */
@@ -219,7 +220,7 @@
     FeedAppLifecycleListener lifecycleListener = new FeedAppLifecycleListener(threadUtils);
     lifecycleListener.registerObserver(store);
 
-    ProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
+    ProtocolAdapter protocolAdapter = new FeedProtocolAdapter(Collections.emptyList(), timingUtils);
     ActionReader actionReader =
         new FeedActionReader(store, clock, protocolAdapter, taskQueue, configuration);
     FeedRequestManager feedRequestManager =
diff --git a/src/main/java/com/google/android/libraries/feed/api/internal/protocoladapter/RequiredContentAdapter.java b/src/main/java/com/google/android/libraries/feed/api/internal/protocoladapter/RequiredContentAdapter.java
new file mode 100644
index 0000000..4f938cd
--- /dev/null
+++ b/src/main/java/com/google/android/libraries/feed/api/internal/protocoladapter/RequiredContentAdapter.java
@@ -0,0 +1,27 @@
+// Copyright 2019 The Feed Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.android.libraries.feed.api.internal.protocoladapter;
+
+import com.google.search.now.wire.feed.ContentIdProto.ContentId;
+import com.google.search.now.wire.feed.DataOperationProto.DataOperation;
+import java.util.List;
+
+/**
+ * Interface that creates an extension point where implementers can indicate that a DataOperation is
+ * dependent on other content.
+ */
+public interface RequiredContentAdapter {
+  List<ContentId> determineRequiredContentIds(DataOperation dataOperation);
+}
diff --git a/src/main/java/com/google/android/libraries/feed/common/testing/InfraIntegrationScope.java b/src/main/java/com/google/android/libraries/feed/common/testing/InfraIntegrationScope.java
index b5e987a..24c4e55 100644
--- a/src/main/java/com/google/android/libraries/feed/common/testing/InfraIntegrationScope.java
+++ b/src/main/java/com/google/android/libraries/feed/common/testing/InfraIntegrationScope.java
@@ -48,6 +48,7 @@
 import com.google.android.libraries.feed.testing.requestmanager.FakeFeedRequestManager;
 import com.google.protobuf.GeneratedMessageLite.GeneratedExtension;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.Executor;
 import java.util.concurrent.TimeUnit;
@@ -116,7 +117,7 @@
             fakeClock,
             fakeBasicLoggingApi,
             fakeMainThreadRunner);
-    feedProtocolAdapter = new FeedProtocolAdapter(timingUtils);
+    feedProtocolAdapter = new FeedProtocolAdapter(Collections.emptyList(), timingUtils);
     fakeFeedRequestManager =
         new FakeFeedRequestManager(
             fakeThreadUtils, fakeMainThreadRunner, feedProtocolAdapter, taskQueue);
diff --git a/src/main/java/com/google/android/libraries/feed/feedprotocoladapter/FeedProtocolAdapter.java b/src/main/java/com/google/android/libraries/feed/feedprotocoladapter/FeedProtocolAdapter.java
index 5c6fcce..2288b19 100644
--- a/src/main/java/com/google/android/libraries/feed/feedprotocoladapter/FeedProtocolAdapter.java
+++ b/src/main/java/com/google/android/libraries/feed/feedprotocoladapter/FeedProtocolAdapter.java
@@ -15,6 +15,7 @@
 package com.google.android.libraries.feed.feedprotocoladapter;
 
 import com.google.android.libraries.feed.api.internal.protocoladapter.ProtocolAdapter;
+import com.google.android.libraries.feed.api.internal.protocoladapter.RequiredContentAdapter;
 import com.google.android.libraries.feed.common.Result;
 import com.google.android.libraries.feed.common.Validators;
 import com.google.android.libraries.feed.common.logging.Dumpable;
@@ -45,7 +46,9 @@
 import com.google.search.now.wire.feed.ResponseProto.Response;
 import com.google.search.now.wire.feed.TokenProto.Token;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
 /** A ProtocolAdapter which converts from the wire protocol to the internal protocol. */
 public final class FeedProtocolAdapter implements ProtocolAdapter, Dumpable {
@@ -54,14 +57,17 @@
   static final String CONTENT_ID_DELIMITER = "::";
 
   private final List<DataOperationTransformer> dataOperationTransformers;
+  private final List<RequiredContentAdapter> requiredContentAdapters;
   private final TimingUtils timingUtils;
 
   // Operation counts for #dump(Dumpable)
   private int responseHandlingCount = 0;
   private int convertContentIdCount = 0;
 
-  public FeedProtocolAdapter(TimingUtils timingUtils) {
+  public FeedProtocolAdapter(
+      List<RequiredContentAdapter> requiredContentAdapters, TimingUtils timingUtils) {
     this.timingUtils = timingUtils;
+    this.requiredContentAdapters = requiredContentAdapters;
     dataOperationTransformers = new ArrayList<>(2);
     dataOperationTransformers.add(new FeatureDataOperationTransformer());
     dataOperationTransformers.add(new ContentDataOperationTransformer());
@@ -119,10 +125,13 @@
       List<DataOperation> dataOperations, FeedResponseMetadata responseMetadata) {
     ElapsedTimeTracker totalTimeTracker = timingUtils.getElapsedTimeTracker(TAG);
     List<StreamDataOperation> streamDataOperations = new ArrayList<>();
+    Set<ContentId> requiredContentIds = new HashSet<>();
     for (DataOperation operation : dataOperations) {
-      // The operations defined in stream_data.proto and data_operation.proto have the same
-      // integer value
-      Operation streamOperation = Operation.forNumber(operation.getOperation().getNumber());
+      for (RequiredContentAdapter adapter : requiredContentAdapters) {
+        requiredContentIds.addAll(adapter.determineRequiredContentIds(operation));
+      }
+
+      Operation streamOperation = operationToStreamOperation(operation.getOperation());
       String contentId;
       if (streamOperation == Operation.CLEAR_ALL) {
         streamDataOperations.add(createDataOperation(Operation.CLEAR_ALL, null, null).build());
@@ -179,6 +188,17 @@
                 .build());
       }
     }
+
+    for (ContentId requiredContentId : requiredContentIds) {
+      streamDataOperations.add(
+          StreamDataOperation.newBuilder()
+              .setStreamStructure(
+                  StreamStructure.newBuilder()
+                      .setOperation(Operation.REQUIRED_CONTENT)
+                      .setContentId(createContentId(requiredContentId)))
+              .build());
+    }
+
     totalTimeTracker.stop("task", "convertWireProtocol", "mutations", dataOperations.size());
     return Result.success(streamDataOperations);
   }
@@ -188,7 +208,7 @@
       FeedResponseMetadata feedResponseMetadata,
       String contentId,
       List<StreamDataOperation> streamDataOperations) {
-    Operation streamOperation = Operation.forNumber(operation.getOperation().getNumber());
+    Operation streamOperation = operationToStreamOperation(operation.getOperation());
     String parentId = null;
     if (operation.getFeature().hasParentId()) {
       parentId = createContentId(operation.getFeature().getParentId());
@@ -328,4 +348,19 @@
     dumper.forKey("responseHandlingCount").value(responseHandlingCount);
     dumper.forKey("convertContentIdCount").value(convertContentIdCount).compactPrevious();
   }
+
+  private static Operation operationToStreamOperation(DataOperation.Operation operation) {
+    switch (operation) {
+      case UNKNOWN_OPERATION:
+        return Operation.UNKNOWN;
+      case CLEAR_ALL:
+        return Operation.CLEAR_ALL;
+      case UPDATE_OR_APPEND:
+        return Operation.UPDATE_OR_APPEND;
+      case REMOVE:
+        return Operation.REMOVE;
+    }
+
+    return Operation.UNKNOWN;
+  }
 }
diff --git a/src/test/java/com/google/android/libraries/feed/feedprotocoladapter/BUILD b/src/test/java/com/google/android/libraries/feed/feedprotocoladapter/BUILD
index 8b90994..78756b3 100644
--- a/src/test/java/com/google/android/libraries/feed/feedprotocoladapter/BUILD
+++ b/src/test/java/com/google/android/libraries/feed/feedprotocoladapter/BUILD
@@ -10,6 +10,7 @@
     aapt_version = "aapt2",
     manifest_values = DEFAULT_ANDROID_LOCAL_TEST_MANIFEST,
     deps = [
+        "//src/main/java/com/google/android/libraries/feed/api/internal/protocoladapter",
         "//src/main/java/com/google/android/libraries/feed/common",
         "//src/main/java/com/google/android/libraries/feed/common/testing",
         "//src/main/java/com/google/android/libraries/feed/common/time",
@@ -18,6 +19,7 @@
         "//src/main/proto/search/now/wire/feed:feed_java_proto_lite",
         "//third_party:robolectric",
         "@com_google_protobuf_javalite//:protobuf_java_lite",
+        "@maven//:com_google_guava_guava",
         "@maven//:com_google_truth_truth",
         "@maven//:org_mockito_mockito_core",
         "@robolectric//bazel:android-all",
diff --git a/src/test/java/com/google/android/libraries/feed/feedprotocoladapter/FeedProtocolAdapterTest.java b/src/test/java/com/google/android/libraries/feed/feedprotocoladapter/FeedProtocolAdapterTest.java
index 3e85617..ce6ad96 100644
--- a/src/test/java/com/google/android/libraries/feed/feedprotocoladapter/FeedProtocolAdapterTest.java
+++ b/src/test/java/com/google/android/libraries/feed/feedprotocoladapter/FeedProtocolAdapterTest.java
@@ -15,14 +15,22 @@
 package com.google.android.libraries.feed.feedprotocoladapter;
 
 import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 import static org.mockito.MockitoAnnotations.initMocks;
 
+import com.google.android.libraries.feed.api.internal.protocoladapter.RequiredContentAdapter;
 import com.google.android.libraries.feed.common.Result;
 import com.google.android.libraries.feed.common.testing.ResponseBuilder;
 import com.google.android.libraries.feed.common.time.TimingUtils;
+import com.google.common.collect.ImmutableList;
 import com.google.protobuf.ByteString;
 import com.google.search.now.feed.client.StreamDataProto.StreamDataOperation;
+import com.google.search.now.feed.client.StreamDataProto.StreamStructure;
 import com.google.search.now.wire.feed.ContentIdProto.ContentId;
+import com.google.search.now.wire.feed.DataOperationProto.DataOperation;
 import com.google.search.now.wire.feed.OpaqueActionDataProto.OpaqueActionData;
 import com.google.search.now.wire.feed.ResponseProto.Response;
 import java.nio.charset.Charset;
@@ -30,6 +38,7 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.Mock;
 import org.robolectric.RobolectricTestRunner;
 
 /** Tests of the {@link FeedProtocolAdapter} class. */
@@ -37,6 +46,10 @@
 public class FeedProtocolAdapterTest {
 
   private final TimingUtils timingUtils = new TimingUtils();
+  private final FeedProtocolAdapter protocolAdapter =
+      new FeedProtocolAdapter(ImmutableList.of(), timingUtils);
+
+  @Mock private RequiredContentAdapter adapter;
   private ResponseBuilder responseBuilder;
 
   @Before
@@ -47,7 +60,6 @@
 
   @Test
   public void testConvertContentId() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     ContentId contentId = ResponseBuilder.createFeatureContentId(13);
     String streamContentId = protocolAdapter.getStreamContentId(contentId);
     assertThat(streamContentId).isNotNull();
@@ -58,7 +70,6 @@
 
   @Test
   public void testConvertContentId_malformed_notThreeParts() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     String streamContentId = "test" + FeedProtocolAdapter.CONTENT_ID_DELIMITER + "break";
     Result<ContentId> contentIdResult = protocolAdapter.getWireContentId(streamContentId);
     assertThat(contentIdResult.isSuccessful()).isFalse();
@@ -66,7 +77,6 @@
 
   @Test
   public void testConvertContentId_malformed_nonNumericId() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     String streamContentId =
         "test"
             + FeedProtocolAdapter.CONTENT_ID_DELIMITER
@@ -79,7 +89,6 @@
 
   @Test
   public void testConvertContentId_roundTrip() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     ContentId contentId = ResponseBuilder.createFeatureContentId(13);
     String streamContentId = protocolAdapter.getStreamContentId(contentId);
     Result<ContentId> contentIdResult = protocolAdapter.getWireContentId(streamContentId);
@@ -90,7 +99,6 @@
 
   @Test
   public void testConvertContentId_roundTrip_partialContentId() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     ContentId contentId = ContentId.newBuilder().setId(13).build();
     String streamContentId = protocolAdapter.getStreamContentId(contentId);
     Result<ContentId> contentIdResult = protocolAdapter.getWireContentId(streamContentId);
@@ -101,7 +109,6 @@
 
   @Test
   public void testSimpleResponse_clear() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     Response response = responseBuilder.addClearOperation().build();
     Result<List<StreamDataOperation>> results = protocolAdapter.createModel(response);
     assertThat(results.isSuccessful()).isTrue();
@@ -110,7 +117,6 @@
 
   @Test
   public void testSimpleResponse_feature() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     Response response = responseBuilder.addRootFeature().build();
 
     Result<List<StreamDataOperation>> results = protocolAdapter.createModel(response);
@@ -128,7 +134,6 @@
 
   @Test
   public void testSimpleResponse_feature_semanticProperties() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     ContentId contentId = ResponseBuilder.createFeatureContentId(13);
     ByteString semanticData = ByteString.copyFromUtf8("helloWorld");
     Response response =
@@ -145,7 +150,6 @@
 
   @Test
   public void testSimpleResponse_feature_actionProperties() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     ContentId contentId = ResponseBuilder.createFeatureContentId(13);
     OpaqueActionData actionData = OpaqueActionData.getDefaultInstance();
     Response response = new ResponseBuilder().addCardWithActionData(contentId, actionData).build();
@@ -170,7 +174,6 @@
             .addClusterFeature(clusterId, rootId)
             .addCard(cardId, clusterId)
             .build();
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
 
     Result<List<StreamDataOperation>> results = protocolAdapter.createModel(response);
     assertThat(results.isSuccessful()).isTrue();
@@ -185,7 +188,6 @@
 
   @Test
   public void testResponse_remove() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     Response response =
         responseBuilder
             .removeFeature(ContentId.getDefaultInstance(), ContentId.getDefaultInstance())
@@ -197,7 +199,6 @@
 
   @Test
   public void testPietSharedState() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     Response response = responseBuilder.addPietSharedState().build();
     Result<List<StreamDataOperation>> results = protocolAdapter.createModel(response);
     assertThat(results.isSuccessful()).isTrue();
@@ -211,7 +212,6 @@
 
   @Test
   public void testContinuationToken_nextPageToken() {
-    FeedProtocolAdapter protocolAdapter = new FeedProtocolAdapter(timingUtils);
     ByteString tokenForMutation = ByteString.copyFrom("token", Charset.defaultCharset());
     Response response = responseBuilder.addStreamToken(1, tokenForMutation).build();
 
@@ -222,4 +222,43 @@
     assertThat(sdo.hasStreamPayload()).isTrue();
     assertThat(sdo.getStreamPayload().hasStreamToken()).isTrue();
   }
+
+  @Test
+  public void testRequiredContentAdapter() {
+    when(adapter.determineRequiredContentIds(any(DataOperation.class)))
+        .thenReturn(ImmutableList.of(ContentId.newBuilder().setId(1).build()))
+        .thenReturn(ImmutableList.of(ContentId.newBuilder().setId(2).build()))
+        .thenReturn(
+            ImmutableList.of(
+                ContentId.newBuilder().setId(1).build(),
+                ContentId.newBuilder().setId(2).build(),
+                ContentId.newBuilder().setId(3).build()));
+    Response response =
+        responseBuilder
+            .addRootFeature(ContentId.getDefaultInstance())
+            .addClusterFeature(ContentId.getDefaultInstance(), ContentId.getDefaultInstance())
+            .addCard(ContentId.getDefaultInstance(), ContentId.getDefaultInstance())
+            .build();
+
+    FeedProtocolAdapter protocolAdapter =
+        new FeedProtocolAdapter(ImmutableList.of(adapter), timingUtils);
+    Result<List<StreamDataOperation>> result = protocolAdapter.createModel(response);
+
+    verify(adapter, times(4)).determineRequiredContentIds(any(DataOperation.class));
+    assertThat(result.isSuccessful()).isTrue();
+    List<StreamDataOperation> operations = result.getValue();
+    assertThat(operations).hasSize(7);
+    assertThat(operations.get(4).getStreamStructure().getOperation())
+        .isEqualTo(StreamStructure.Operation.REQUIRED_CONTENT);
+    assertThat(operations.get(5).getStreamStructure().getOperation())
+        .isEqualTo(StreamStructure.Operation.REQUIRED_CONTENT);
+    assertThat(operations.get(6).getStreamStructure().getOperation())
+        .isEqualTo(StreamStructure.Operation.REQUIRED_CONTENT);
+    assertThat(
+            ImmutableList.of(
+                operations.get(4).getStreamStructure().getContentId(),
+                operations.get(5).getStreamStructure().getContentId(),
+                operations.get(6).getStreamStructure().getContentId()))
+        .containsExactly("::::1", "::::2", "::::3");
+  }
 }