Plumb schema version into persisted StreamSession data. The schema version for
a specific session is equal to the schema version of HEAD when the session is
created.

PiperOrigin-RevId: 251315760
Change-Id: I9a572968f4111f9f710c46395488e53b8b9c85ee
diff --git a/src/main/java/com/google/android/libraries/feed/feedsessionmanager/FeedSessionManagerImpl.java b/src/main/java/com/google/android/libraries/feed/feedsessionmanager/FeedSessionManagerImpl.java
index 0aa4956..2ca4afa 100644
--- a/src/main/java/com/google/android/libraries/feed/feedsessionmanager/FeedSessionManagerImpl.java
+++ b/src/main/java/com/google/android/libraries/feed/feedsessionmanager/FeedSessionManagerImpl.java
@@ -340,7 +340,8 @@
       modelProvider.raiseError(Validators.checkNotNull(noCardsError));
       String sessionId = streamSessionResult.getValue();
       session.setSessionId(sessionId);
-      sessionCache.putAttached(sessionId, clock.currentTimeMillis(), session);
+      sessionCache.putAttached(
+          sessionId, clock.currentTimeMillis(), sessionCache.getHead().getSchemaVersion(), session);
       synchronized (lock) {
         sessionsUnderConstruction.remove(session);
       }
@@ -374,7 +375,8 @@
         cachedBindings,
         shouldAppendOutstandingRequest,
         uiContext);
-    sessionCache.putAttached(sessionId, creationTimeMillis, session);
+    sessionCache.putAttached(
+        sessionId, creationTimeMillis, sessionCache.getHead().getSchemaVersion(), session);
     synchronized (lock) {
       sessionsUnderConstruction.remove(session);
     }
diff --git a/src/main/java/com/google/android/libraries/feed/feedsessionmanager/internal/HeadSessionImpl.java b/src/main/java/com/google/android/libraries/feed/feedsessionmanager/internal/HeadSessionImpl.java
index 1392a29..bee56f3 100644
--- a/src/main/java/com/google/android/libraries/feed/feedsessionmanager/internal/HeadSessionImpl.java
+++ b/src/main/java/com/google/android/libraries/feed/feedsessionmanager/internal/HeadSessionImpl.java
@@ -41,6 +41,8 @@
   private final SessionContentTracker sessionContentTracker =
       new SessionContentTracker(/* supportsClearAll= */ true);
 
+  private int schemaVersion = 0;
+
   // operation counts for the dumper
   private int updateCount = 0;
   private int storeMutationFailures = 0;
@@ -51,8 +53,9 @@
   }
 
   /** Initialize head from the stored stream structures. */
-  void initializeSession(List<StreamStructure> streamStructures) {
+  void initializeSession(List<StreamStructure> streamStructures, int schemaVersion) {
     Logger.i(TAG, "Initialize HEAD %s items", streamStructures.size());
+    this.schemaVersion = schemaVersion;
     sessionContentTracker.update(streamStructures);
   }
 
@@ -60,6 +63,10 @@
     sessionContentTracker.clear();
   }
 
+  public int getSchemaVersion() {
+    return schemaVersion;
+  }
+
   @Override
   public boolean invalidateOnResetHead() {
     // There won't be a ModelProvider to invalidate
diff --git a/src/main/java/com/google/android/libraries/feed/feedsessionmanager/internal/SessionCache.java b/src/main/java/com/google/android/libraries/feed/feedsessionmanager/internal/SessionCache.java
index 00937cb..10e2157 100644
--- a/src/main/java/com/google/android/libraries/feed/feedsessionmanager/internal/SessionCache.java
+++ b/src/main/java/com/google/android/libraries/feed/feedsessionmanager/internal/SessionCache.java
@@ -190,7 +190,8 @@
   }
 
   /** Add a {@link Session} to the SessionCache. */
-  public void putAttached(String sessionId, long creationTimeMillis, Session session) {
+  public void putAttached(
+      String sessionId, long creationTimeMillis, int schemaVersion, Session session) {
     Logger.d(TAG, "putAttached, sessionId=%s", sessionId);
 
     threadUtils.checkNotMainThread();
@@ -199,7 +200,10 @@
       attachedSessions.put(sessionId, session);
       sessionsMetadata.put(
           sessionId,
-          SessionMetadata.newBuilder().setCreationTimeMillis(creationTimeMillis).build());
+          SessionMetadata.newBuilder()
+              .setCreationTimeMillis(creationTimeMillis)
+              .setSchemaVersion(schemaVersion)
+              .build());
       Logger.d(
           TAG,
           "Sessions size: attached=%d, all=%d",
@@ -284,11 +288,14 @@
       Logger.w(TAG, "unable to get head stream structures");
       return false;
     }
-    head.initializeSession(results.getValue());
+
+    List<StreamSession> sessionList = getPersistedSessions();
+    int headSchemaVersion = getHeadSchemaVersion(sessionList);
+    head.initializeSession(results.getValue(), headSchemaVersion);
     initialized = true;
     headTimeTracker.stop("", "createHead");
 
-    initializePersistedSessions();
+    initializePersistedSessions(sessionList);
 
     // Ensure that SessionMetadata exists for HEAD.
     synchronized (lock) {
@@ -337,9 +344,11 @@
   void updateHeadLastAddedTimeMillis(long lastAddedTimeMillis) {
     threadUtils.checkNotMainThread();
     synchronized (lock) {
+      SessionMetadata metadata = sessionsMetadata.get(head.getSessionId());
+      SessionMetadata.Builder builder =
+          metadata == null ? SessionMetadata.newBuilder() : metadata.toBuilder();
       sessionsMetadata.put(
-          head.getSessionId(),
-          SessionMetadata.newBuilder().setLastAddedTimeMillis(lastAddedTimeMillis).build());
+          head.getSessionId(), builder.setLastAddedTimeMillis(lastAddedTimeMillis).build());
     }
     updatePersistedSessionsMetadata();
   }
@@ -395,9 +404,8 @@
     return unboundSessions.values();
   }
 
-  private void initializePersistedSessions() {
+  private void initializePersistedSessions(List<StreamSession> sessionList) {
     threadUtils.checkNotMainThread();
-    List<StreamSession> sessionList = getPersistedSessions();
 
     HeadSessionImpl headSession = Validators.checkNotNull(head);
     String headSessionId = headSession.getSessionId();
@@ -598,6 +606,16 @@
     return metadataBuilder.build();
   }
 
+  private int getHeadSchemaVersion(List<StreamSession> sessionList) {
+    for (StreamSession streamSession : sessionList) {
+      if (streamSession.getSessionId().equals(head.getSessionId())) {
+        return streamSession.getSessionMetadata().getSchemaVersion();
+      }
+    }
+
+    return 0;
+  }
+
   @Override
   public void dump(Dumper dumper) {
     dumper.title(TAG);
diff --git a/src/main/proto/com/google/android/libraries/feed/api/internal/proto/stream_data.proto b/src/main/proto/com/google/android/libraries/feed/api/internal/proto/stream_data.proto
index 25858ec..9b9d6f5 100644
--- a/src/main/proto/com/google/android/libraries/feed/api/internal/proto/stream_data.proto
+++ b/src/main/proto/com/google/android/libraries/feed/api/internal/proto/stream_data.proto
@@ -227,6 +227,9 @@
 
   // The time in milliseconds that this session was created.
   optional int64 creation_time_millis = 2;
+
+  // The schema used to create this session.
+  optional int32 schema_version = 3;
 }
 
 message StreamLocalAction {
diff --git a/src/test/java/com/google/android/libraries/feed/feedsessionmanager/BUILD b/src/test/java/com/google/android/libraries/feed/feedsessionmanager/BUILD
index cbef0f8..aba4e26 100644
--- a/src/test/java/com/google/android/libraries/feed/feedsessionmanager/BUILD
+++ b/src/test/java/com/google/android/libraries/feed/feedsessionmanager/BUILD
@@ -40,6 +40,7 @@
         "//third_party:robolectric",
         "@com_google_code_findbugs_jsr305//jar",
         "@com_google_protobuf_javalite//:protobuf_java_lite",
+        "@maven//:com_google_guava_guava",
         "@maven//:com_google_truth_truth",
         "@maven//:org_mockito_mockito_core",
         "@robolectric//bazel:android-all",
diff --git a/src/test/java/com/google/android/libraries/feed/feedsessionmanager/FeedSessionManagerImplTest.java b/src/test/java/com/google/android/libraries/feed/feedsessionmanager/FeedSessionManagerImplTest.java
index 0f04a4b..9e006a8 100644
--- a/src/test/java/com/google/android/libraries/feed/feedsessionmanager/FeedSessionManagerImplTest.java
+++ b/src/test/java/com/google/android/libraries/feed/feedsessionmanager/FeedSessionManagerImplTest.java
@@ -62,6 +62,8 @@
 import com.google.android.libraries.feed.testing.requestmanager.FakeActionUploadRequestManager;
 import com.google.android.libraries.feed.testing.requestmanager.FakeFeedRequestManager;
 import com.google.android.libraries.feed.testing.store.FakeStore;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 import com.google.protobuf.ByteString;
 import com.google.search.now.feed.client.StreamDataProto.StreamDataOperation;
 import com.google.search.now.feed.client.StreamDataProto.StreamPayload;
@@ -81,12 +83,6 @@
 import com.google.search.now.wire.feed.PietSharedStateItemProto.PietSharedStateItem;
 import com.google.search.now.wire.feed.ResponseProto.Response;
 import java.nio.charset.Charset;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -161,17 +157,13 @@
     assertThat(sessionManager.initialized.get()).isFalse();
     sessionManager.initialize();
     assertThat(sessionManager.initialized.get()).isTrue();
-
-    Map<String, StreamSharedState> sharedStateCache = sessionManager.getSharedStateCacheForTest();
-    assertThat(sharedStateCache).hasSize(1);
+    assertThat(sessionManager.getSharedStateCacheForTest()).hasSize(1);
 
     SessionCache sessionCache = sessionManager.getSessionCacheForTest();
     Session head = sessionCache.getHead();
     assertThat(head).isInstanceOf(HeadSessionImpl.class);
     String itemKey = idGenerators.createFeatureContentId(0);
-    Set<String> content = head.getContentInSession();
-    assertThat(content).contains(itemKey);
-    assertThat(content).hasSize(1);
+    assertThat(head.getContentInSession()).containsExactly(itemKey);
   }
 
   // This is testing a condition similar to the one that caused [INTERNAL LINK].
@@ -220,9 +212,7 @@
     assertThat(sessionManager.initialized.get()).isFalse();
     sessionManager.initialize();
     assertThat(sessionManager.initialized.get()).isTrue();
-
-    Map<String, StreamSharedState> sharedStateCache = sessionManager.getSharedStateCacheForTest();
-    assertThat(sharedStateCache).hasSize(2);
+    assertThat(sessionManager.getSharedStateCacheForTest()).hasSize(2);
 
     StreamSharedState cachedSharedState1 = sessionManager.getSharedState(contentId1);
     StreamSharedState cachedSharedState2 = sessionManager.getSharedState(contentId2);
@@ -515,7 +505,7 @@
             .build();
 
     Consumer<Result<Model>> updateConsumer = sessionManager.getUpdateConsumer(EMPTY_MUTATION);
-    Result<Model> result = Result.success(Model.of(listOf(streamDataOperation)));
+    Result<Model> result = Result.success(Model.of(ImmutableList.of(streamDataOperation)));
     updateConsumer.accept(result);
 
     assertThat(fakeStore.getContentById(rootContentId))
@@ -537,7 +527,7 @@
             .build();
 
     Consumer<Result<Model>> updateConsumer = sessionManager.getUpdateConsumer(EMPTY_MUTATION);
-    Result<Model> result = Result.success(Model.of(listOf(streamDataOperation)));
+    Result<Model> result = Result.success(Model.of(ImmutableList.of(streamDataOperation)));
     updateConsumer.accept(result);
 
     assertThat(fakeStore.getContentById(rootContentId))
@@ -560,25 +550,21 @@
     int featureCount = 3;
     populateSession(sessionManager, featureCount, 1, true, sharedStateId);
 
-    Map<String, StreamSharedState> sharedStates = sessionManager.getSharedStateCacheForTest();
-    assertThat(sharedStates).hasSize(1);
+    assertThat(sessionManager.getSharedStateCacheForTest()).hasSize(1);
     SessionCache sessionCache = sessionManager.getSessionCacheForTest();
     assertThat(sessionCache.getAttachedSessions()).isEmpty();
     Session session = sessionCache.getHead();
     assertThat(session).isNotNull();
-    Set<String> contentInSession = session.getContentInSession();
-    assertThat(contentInSession).hasSize(featureCount + 1);
+    assertThat(session.getContentInSession()).hasSize(featureCount + 1);
 
     fakeThreadUtils.enforceMainThread(false);
     sessionManager.onSwitchToEphemeralMode();
 
-    sharedStates = sessionManager.getSharedStateCacheForTest();
-    assertThat(sharedStates).hasSize(0);
+    assertThat(sessionManager.getSharedStateCacheForTest()).isEmpty();
     assertThat(sessionCache.getAttachedSessions()).isEmpty();
     session = sessionCache.getHead();
     assertThat(session).isNotNull();
-    contentInSession = session.getContentInSession();
-    assertThat(contentInSession).hasSize(0);
+    assertThat(session.getContentInSession()).isEmpty();
   }
 
   @Test
@@ -624,8 +610,8 @@
     FeedSessionManagerImpl sessionManager =
         getInitializedSessionManager(
             new Configuration.Builder().put(ConfigKey.UNDOABLE_ACTIONS_ENABLED, true).build());
-    HashSet<StreamUploadableAction> actionSet = new HashSet<>();
-    actionSet.add(StreamUploadableAction.getDefaultInstance());
+    ImmutableSet<StreamUploadableAction> actionSet =
+        ImmutableSet.of(StreamUploadableAction.getDefaultInstance());
     ConsistencyToken token =
         ConsistencyToken.newBuilder().setToken(ByteString.copyFrom(new byte[] {0x1, 0xf})).build();
     fakeThreadUtils.enforceMainThread(false);
@@ -830,10 +816,4 @@
             appLifecycleListener)
         .create();
   }
-
-  private static <T> List<T> listOf(T... items) {
-    ArrayList<T> result = new ArrayList<>(items.length);
-    Collections.addAll(result, items);
-    return result;
-  }
 }
diff --git a/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/BUILD b/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/BUILD
index 4e0bf67..00472e9 100644
--- a/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/BUILD
+++ b/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/BUILD
@@ -40,6 +40,7 @@
         "//src/main/proto/com/google/android/libraries/feed/api/internal/proto:client_feed_java_proto_lite",
         "//third_party:robolectric",
         "@com_google_protobuf_javalite//:protobuf_java_lite",
+        "@maven//:com_google_guava_guava",
         "@maven//:com_google_truth_truth",
         "@robolectric//bazel:android-all",
     ],
diff --git a/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/HeadSessionImplTest.java b/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/HeadSessionImplTest.java
index e78cb81..dbbef79 100644
--- a/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/HeadSessionImplTest.java
+++ b/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/HeadSessionImplTest.java
@@ -25,6 +25,7 @@
 import com.google.android.libraries.feed.common.time.TimingUtils;
 import com.google.android.libraries.feed.common.time.testing.FakeClock;
 import com.google.android.libraries.feed.testing.store.FakeStore;
+import com.google.common.collect.ImmutableList;
 import com.google.search.now.feed.client.StreamDataProto.StreamStructure;
 import com.google.search.now.feed.client.StreamDataProto.StreamToken;
 import java.util.List;
@@ -258,6 +259,13 @@
     assertThat(getContentInSession()).hasSize(1);
   }
 
+  @Test
+  public void testInitializeSession_schemaVersion() {
+    int schemaVersion = 3;
+    headSession.initializeSession(ImmutableList.of(), schemaVersion);
+    assertThat(headSession.getSchemaVersion()).isEqualTo(schemaVersion);
+  }
+
   private void addFeatures(InternalProtocolBuilder protocolBuilder, int featureCnt, int startId) {
     for (int i = 0; i < featureCnt; i++) {
       protocolBuilder.addFeature(
@@ -269,7 +277,8 @@
   /** Re-read the session from disk and return the set of content. */
   private Set<String> getContentInSession() {
     HeadSessionImpl headSession = new HeadSessionImpl(fakeStore, timingUtils);
-    headSession.initializeSession(fakeStore.getStreamStructures(Store.HEAD_SESSION_ID).getValue());
+    headSession.initializeSession(
+        fakeStore.getStreamStructures(Store.HEAD_SESSION_ID).getValue(), /* schemaVersion= */ 0);
     return headSession.getContentInSession();
   }
 }
diff --git a/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/SessionCacheTest.java b/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/SessionCacheTest.java
index a82de1c..4ef662c 100644
--- a/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/SessionCacheTest.java
+++ b/src/test/java/com/google/android/libraries/feed/feedsessionmanager/internal/SessionCacheTest.java
@@ -47,6 +47,7 @@
 @RunWith(RobolectricTestRunner.class)
 public class SessionCacheTest {
   private static final long DEFAULT_LIFETIME_MS = 10;
+  private static final int SCHEMA_VERSION = 4;
 
   private final Configuration configuration = new Configuration.Builder().build();
   private final ContentIdGenerators idGenerators = new ContentIdGenerators();
@@ -73,7 +74,16 @@
 
   @Test
   public void testInitialization() {
+    int schemaVersion = 3;
     populateHead();
+    mockStreamSessions(
+        StreamSessions.newBuilder()
+            .addStreamSession(
+                StreamSession.newBuilder()
+                    .setSessionId(Store.HEAD_SESSION_ID)
+                    .setSessionMetadata(
+                        SessionMetadata.newBuilder().setSchemaVersion(schemaVersion)))
+            .build());
     assertThat(sessionCache.getAttachedSessions()).isEmpty();
     assertThat(sessionCache.isHeadInitialized()).isFalse();
     assertThat(sessionCache.getHead()).isNotNull();
@@ -85,6 +95,7 @@
     assertThat(sessionCache.getAttachedSessions()).isEmpty();
     assertThat(sessionCache.getHead()).isNotNull();
     assertThat(sessionCache.getHead().isHeadEmpty()).isFalse();
+    assertThat(sessionCache.getHead().getSchemaVersion()).isEqualTo(schemaVersion);
   }
 
   @Test
@@ -101,13 +112,15 @@
   public void testPut_persisted() {
     sessionCache.initialize();
     Session session = populateSession(1, 2);
-    sessionCache.putAttached(session.getSessionId(), /* creationTimeMillis= */ 0L, session);
+    sessionCache.putAttached(
+        session.getSessionId(), /* creationTimeMillis= */ 0L, SCHEMA_VERSION, session);
 
     Session ret = sessionCache.getAttached(session.getSessionId());
     assertThat(ret).isEqualTo(session);
 
     session = populateSession(2, 2);
-    sessionCache.putAttached(session.getSessionId(), /* creationTimeMillis= */ 0L, session);
+    sessionCache.putAttached(
+        session.getSessionId(), /* creationTimeMillis= */ 0L, SCHEMA_VERSION, session);
 
     List<StreamSession> streamSessionList = sessionCache.getPersistedSessions();
     assertThat(streamSessionList).hasSize(3);
@@ -117,7 +130,8 @@
   public void testRemove() {
     sessionCache.initialize();
     Session session = populateSession(1, 2);
-    sessionCache.putAttached(session.getSessionId(), /* creationTimeMillis= */ 0L, session);
+    sessionCache.putAttached(
+        session.getSessionId(), /* creationTimeMillis= */ 0L, SCHEMA_VERSION, session);
 
     List<StreamSession> streamSessionList = sessionCache.getPersistedSessions();
     assertThat(streamSessionList).hasSize(2);
@@ -168,14 +182,14 @@
 
     Session s1 = populateSession(1, 2, /* commitToStore= */ true);
     String s1Id = s1.getSessionId();
-    sessionCache.putAttached(s1Id, 1L, s1);
+    sessionCache.putAttached(s1Id, 1L, SCHEMA_VERSION, s1);
 
     assertThat(sessionCache.getAttachedSessions()).containsExactly(s1);
     assertThat(sessionCache.getAllSessions()).containsExactly(s1, headSession);
 
     Session s2 = populateSession(2, 2, /* commitToStore= */ true);
     String s2Id = s2.getSessionId();
-    sessionCache.putAttached(s2Id, 2L, s2);
+    sessionCache.putAttached(s2Id, 2L, SCHEMA_VERSION, s2);
 
     assertThat(sessionCache.getAttachedSessions()).containsExactly(s1, s2);
     assertThat(sessionCache.getAllSessions()).containsExactly(s1, s2, headSession);
@@ -341,12 +355,18 @@
     StreamSession session1 =
         StreamSession.newBuilder()
             .setSessionId("stream:1")
-            .setSessionMetadata(SessionMetadata.newBuilder().setCreationTimeMillis(0L))
+            .setSessionMetadata(
+                SessionMetadata.newBuilder()
+                    .setCreationTimeMillis(0L)
+                    .setSchemaVersion(SCHEMA_VERSION))
             .build();
     StreamSession session2 =
         StreamSession.newBuilder()
             .setSessionId("stream:2")
-            .setSessionMetadata(SessionMetadata.newBuilder().setCreationTimeMillis(0L))
+            .setSessionMetadata(
+                SessionMetadata.newBuilder()
+                    .setCreationTimeMillis(0L)
+                    .setSchemaVersion(SCHEMA_VERSION))
             .build();
     Session s1 = mock(Session.class);
     when(s1.getSessionId()).thenReturn(session1.getSessionId());
@@ -426,7 +446,8 @@
 
   private void setSessions(Session... testSessions) {
     for (Session session : testSessions) {
-      sessionCache.putAttached(session.getSessionId(), /* creationTimeMillis= */ 0L, session);
+      sessionCache.putAttached(
+          session.getSessionId(), /* creationTimeMillis= */ 0L, SCHEMA_VERSION, session);
     }
   }
 }