Support TS score feedback

Adds support for periodic TS score outputs to the ChromeML API.

Also adds a TestStreamingResponder helper class to the
OnDeviceModel service support library, as this is currently
implemented in the internal repo where it prevents us from
easily changing the mojom.

Bug: b:302395507
Change-Id: I5667b9ece9901bdd472b7484d40bc5b71b6acaa2
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5171183
Commit-Queue: Ken Rockot <rockot@google.com>
Reviewed-by: Clark DuVall <cduvall@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1243677}
diff --git a/services/on_device_model/BUILD.gn b/services/on_device_model/BUILD.gn
index d470060..da903da9 100644
--- a/services/on_device_model/BUILD.gn
+++ b/services/on_device_model/BUILD.gn
@@ -60,6 +60,7 @@
   deps = [
     ":on_device_model_service_for_test",
     "//base/test:test_support",
+    "//services/on_device_model/public/cpp/test_support",
     "//services/on_device_model/public/mojom",
     "//testing/gmock",
     "//testing/gtest",
diff --git a/services/on_device_model/ml/chrome_ml_api.h b/services/on_device_model/ml/chrome_ml_api.h
index 829b564..a8a36d6 100644
--- a/services/on_device_model/ml/chrome_ml_api.h
+++ b/services/on_device_model/ml/chrome_ml_api.h
@@ -7,6 +7,7 @@
 
 #include <cstdint>
 #include <functional>
+#include <vector>
 
 #include "third_party/dawn/include/dawn/dawn_proc_table.h"
 #include "third_party/dawn/include/dawn/webgpu.h"
@@ -85,6 +86,7 @@
   size_t ts_size;
   const void* ts_spm_data;
   size_t ts_spm_size;
+  size_t ts_dimension;
 };
 
 // Function provided from the library that will cancel the corresponding input
@@ -97,6 +99,10 @@
 // that model execution is complete.
 using ChromeMLOutputFn = std::function<void(const std::optional<std::string>&)>;
 
+// Receives periodic updates to TS scores, per `score_ts_interval` set in
+// ChromeMLExecuteOptions.
+using ChromeMLScoreTSFn = std::function<void(const std::vector<float>&)>;
+
 // Called with the number of tokens processed after a call to RunModel()
 // which has the kSave ContextMode set. This will be called on the internal
 // thread executing the model.
@@ -106,6 +112,8 @@
 struct ChromeMLExecutionResult {
   // If true, all prior output received for this model execution is effectively
   // retracted by the library and should be discarded by the client.
+  //
+  // DEPRECATED: Clients should ignore this field. It will be deleted.
   bool retracted;
 };
 
@@ -115,14 +123,16 @@
     std::function<void(const ChromeMLExecutionResult&)>;
 
 struct ChromeMLExecuteOptions {
-  const char* prompt = nullptr;
-  int context_mode = ContextMode::kNone;
-  uint32_t max_tokens = 0;
-  uint32_t token_offset = 0;
-  uint32_t max_output_tokens = 0;
-  const ChromeMLOutputFn* output_fn = nullptr;
-  const ChromeMLContextSavedFn* context_saved_fn = nullptr;
-  const ChromeMLCompletionFn* completion_fn = nullptr;
+  const char* prompt;
+  int context_mode;
+  uint32_t max_tokens;
+  uint32_t token_offset;
+  uint32_t max_output_tokens;
+  uint32_t score_ts_interval;
+  const ChromeMLOutputFn* output_fn;
+  const ChromeMLScoreTSFn* score_ts_fn;
+  const ChromeMLContextSavedFn* context_saved_fn;
+  const ChromeMLCompletionFn* completion_fn;
 };
 
 // Performance data filled out by GetEstimatedPerformance().
diff --git a/services/on_device_model/on_device_model_service_unittest.cc b/services/on_device_model/on_device_model_service_unittest.cc
index 66235b7..d63d2e2 100644
--- a/services/on_device_model/on_device_model_service_unittest.cc
+++ b/services/on_device_model/on_device_model_service_unittest.cc
@@ -7,6 +7,7 @@
 #include "base/test/bind.h"
 #include "base/test/task_environment.h"
 #include "services/on_device_model/public/cpp/model_assets.h"
+#include "services/on_device_model/public/cpp/test_support/test_response_holder.h"
 #include "testing/gmock/include/gmock/gmock.h"
 #include "testing/gtest/include/gtest/gtest.h"
 
@@ -15,28 +16,6 @@
 
 using ::testing::ElementsAre;
 
-class ResponseHolder : public mojom::StreamingResponder {
- public:
-  mojo::PendingRemote<mojom::StreamingResponder> BindRemote() {
-    return receiver_.BindNewPipeAndPassRemote();
-  }
-
-  void OnResponse(const std::string& text) override {
-    responses_.push_back(text);
-  }
-
-  void OnComplete(mojom::ResponseStatus status) override { run_loop_.Quit(); }
-
-  void WaitForCompletion() { run_loop_.Run(); }
-
-  const std::vector<std::string> responses() const { return responses_; }
-
- private:
-  base::RunLoop run_loop_;
-  mojo::Receiver<mojom::StreamingResponder> receiver_{this};
-  std::vector<std::string> responses_;
-};
-
 class ContextClientWaiter : public mojom::ContextClient {
  public:
   mojo::PendingRemote<mojom::ContextClient> BindRemote() {
@@ -94,7 +73,7 @@
 TEST_F(OnDeviceModelServiceTest, Responds) {
   auto model = LoadModel();
   {
-    ResponseHolder response;
+    TestResponseHolder response;
     mojo::Remote<mojom::Session> session;
     model->StartSession(session.BindNewPipeAndPassReceiver());
     session->Execute(MakeInput("bar"), response.BindRemote());
@@ -105,7 +84,7 @@
   }
   // Try another input on  the same model.
   {
-    ResponseHolder response;
+    TestResponseHolder response;
     mojo::Remote<mojom::Session> session;
     model->StartSession(session.BindNewPipeAndPassReceiver());
     session->Execute(MakeInput("cat"), response.BindRemote());
@@ -119,7 +98,7 @@
 TEST_F(OnDeviceModelServiceTest, AddContext) {
   auto model = LoadModel();
 
-  ResponseHolder response;
+  TestResponseHolder response;
   mojo::Remote<mojom::Session> session;
   model->StartSession(session.BindNewPipeAndPassReceiver());
   session->AddContext(MakeInput("cheese"), {});
@@ -135,7 +114,7 @@
 TEST_F(OnDeviceModelServiceTest, IgnoresContext) {
   auto model = LoadModel();
 
-  ResponseHolder response;
+  TestResponseHolder response;
   mojo::Remote<mojom::Session> session;
   model->StartSession(session.BindNewPipeAndPassReceiver());
   session->AddContext(MakeInput("cheese"), {});
@@ -151,7 +130,7 @@
 TEST_F(OnDeviceModelServiceTest, AddContextWithTokenLimits) {
   auto model = LoadModel();
 
-  ResponseHolder response;
+  TestResponseHolder response;
   mojo::Remote<mojom::Session> session;
   model->StartSession(session.BindNewPipeAndPassReceiver());
 
@@ -181,7 +160,7 @@
 TEST_F(OnDeviceModelServiceTest, CancelsPreviousSession) {
   auto model = LoadModel();
 
-  ResponseHolder response1;
+  TestResponseHolder response1;
   mojo::Remote<mojom::Session> session1;
   model->StartSession(session1.BindNewPipeAndPassReceiver());
   session1->Execute(MakeInput("1"), response1.BindRemote());
@@ -200,7 +179,7 @@
   EXPECT_THAT(response1.responses(), ElementsAre("Input: 1\n"));
 
   // Second session still works.
-  ResponseHolder response2;
+  TestResponseHolder response2;
   session2->Execute(MakeInput("2"), response2.BindRemote());
   response2.WaitForCompletion();
   EXPECT_THAT(response2.responses(), ElementsAre("Input: 2\n"));
diff --git a/services/on_device_model/public/cpp/test_support/BUILD.gn b/services/on_device_model/public/cpp/test_support/BUILD.gn
new file mode 100644
index 0000000..746b1cb
--- /dev/null
+++ b/services/on_device_model/public/cpp/test_support/BUILD.gn
@@ -0,0 +1,17 @@
+# Copyright 2024 The Chromium Authors
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+source_set("test_support") {
+  testonly = true
+  sources = [
+    "test_response_holder.cc",
+    "test_response_holder.h",
+  ]
+
+  public_deps = [
+    "//base",
+    "//mojo/public/cpp/bindings",
+    "//services/on_device_model/public/mojom",
+  ]
+}
diff --git a/services/on_device_model/public/cpp/test_support/test_response_holder.cc b/services/on_device_model/public/cpp/test_support/test_response_holder.cc
new file mode 100644
index 0000000..8725a30
--- /dev/null
+++ b/services/on_device_model/public/cpp/test_support/test_response_holder.cc
@@ -0,0 +1,30 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "services/on_device_model/public/cpp/test_support/test_response_holder.h"
+
+namespace on_device_model {
+
+TestResponseHolder::TestResponseHolder() = default;
+
+TestResponseHolder::~TestResponseHolder() = default;
+
+mojo::PendingRemote<mojom::StreamingResponder>
+TestResponseHolder::BindRemote() {
+  return receiver_.BindNewPipeAndPassRemote();
+}
+
+void TestResponseHolder::WaitForCompletion() {
+  run_loop_.Run();
+}
+
+void TestResponseHolder::OnResponse(const std::string& text) {
+  responses_.push_back(text);
+}
+
+void TestResponseHolder::OnComplete(mojom::ResponseStatus status) {
+  run_loop_.Quit();
+}
+
+}  // namespace on_device_model
diff --git a/services/on_device_model/public/cpp/test_support/test_response_holder.h b/services/on_device_model/public/cpp/test_support/test_response_holder.h
new file mode 100644
index 0000000..2e7d0b1
--- /dev/null
+++ b/services/on_device_model/public/cpp/test_support/test_response_holder.h
@@ -0,0 +1,48 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SERVICES_ON_DEVICE_MODEL_PUBLIC_CPP_TEST_SUPPORT_TEST_RESPONSE_HOLDER_H_
+#define SERVICES_ON_DEVICE_MODEL_PUBLIC_CPP_TEST_SUPPORT_TEST_RESPONSE_HOLDER_H_
+
+#include <string>
+#include <vector>
+
+#include "base/run_loop.h"
+#include "mojo/public/cpp/bindings/pending_remote.h"
+#include "mojo/public/cpp/bindings/receiver.h"
+#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
+
+namespace on_device_model {
+
+// Helper to accumulate a streamed response from model execution. This is only
+// used by downstream clients, but is defined upstream to avoid downstream mojom
+// dependencies.
+class TestResponseHolder : public mojom::StreamingResponder {
+ public:
+  TestResponseHolder();
+  ~TestResponseHolder() override;
+
+  // Returns a remote which can be used to stream a response to this object.
+  mojo::PendingRemote<mojom::StreamingResponder> BindRemote();
+
+  // Accumulated responses so far from whoever controls the remote
+  // StreamingResponder endpoint.
+  const std::vector<std::string> responses() const { return responses_; }
+
+  // Spins a RunLoop until this object observes completion of its response.
+  void WaitForCompletion();
+
+  // mojom::StreamingResponder:
+  void OnResponse(const std::string& text) override;
+  void OnComplete(mojom::ResponseStatus status) override;
+
+ private:
+  base::RunLoop run_loop_;
+  mojo::Receiver<mojom::StreamingResponder> receiver_{this};
+  std::vector<std::string> responses_;
+};
+
+}  // namespace on_device_model
+
+#endif  // SERVICES_ON_DEVICE_MODEL_PUBLIC_CPP_TEST_SUPPORT_TEST_RESPONSE_HOLDER_H_