Support TS score feedback
Adds support for periodic TS score outputs to the ChromeML API.
Also adds a TestResponseHolder helper class to the
OnDeviceModel service support library, as this is currently
implemented in the internal repo where it prevents us from
easily changing the mojom.
Bug: b:302395507
Change-Id: I5667b9ece9901bdd472b7484d40bc5b71b6acaa2
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5171183
Commit-Queue: Ken Rockot <rockot@google.com>
Reviewed-by: Clark DuVall <cduvall@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1243677}
diff --git a/services/on_device_model/BUILD.gn b/services/on_device_model/BUILD.gn
index d470060..da903da9 100644
--- a/services/on_device_model/BUILD.gn
+++ b/services/on_device_model/BUILD.gn
@@ -60,6 +60,7 @@
deps = [
":on_device_model_service_for_test",
"//base/test:test_support",
+ "//services/on_device_model/public/cpp/test_support",
"//services/on_device_model/public/mojom",
"//testing/gmock",
"//testing/gtest",
diff --git a/services/on_device_model/ml/chrome_ml_api.h b/services/on_device_model/ml/chrome_ml_api.h
index 829b564..a8a36d6 100644
--- a/services/on_device_model/ml/chrome_ml_api.h
+++ b/services/on_device_model/ml/chrome_ml_api.h
@@ -7,6 +7,7 @@
#include <cstdint>
#include <functional>
+#include <vector>
#include "third_party/dawn/include/dawn/dawn_proc_table.h"
#include "third_party/dawn/include/dawn/webgpu.h"
@@ -85,6 +86,7 @@
size_t ts_size;
const void* ts_spm_data;
size_t ts_spm_size;
+ size_t ts_dimension;
};
// Function provided from the library that will cancel the corresponding input
@@ -97,6 +99,10 @@
// that model execution is complete.
using ChromeMLOutputFn = std::function<void(const std::optional<std::string>&)>;
+// Receives periodic updates to TS scores, per `score_ts_interval` set in
+// ChromeMLExecuteOptions.
+using ChromeMLScoreTSFn = std::function<void(const std::vector<float>&)>;
+
// Called with the number of tokens processed after a call to RunModel()
// which has the kSave ContextMode set. This will be called on the internal
// thread executing the model.
@@ -106,6 +112,8 @@
struct ChromeMLExecutionResult {
// If true, all prior output received for this model execution is effectively
// retracted by the library and should be discarded by the client.
+ //
+ // DEPRECATED: Clients should ignore this field. It will be deleted.
bool retracted;
};
@@ -115,14 +123,16 @@
std::function<void(const ChromeMLExecutionResult&)>;
struct ChromeMLExecuteOptions {
- const char* prompt = nullptr;
- int context_mode = ContextMode::kNone;
- uint32_t max_tokens = 0;
- uint32_t token_offset = 0;
- uint32_t max_output_tokens = 0;
- const ChromeMLOutputFn* output_fn = nullptr;
- const ChromeMLContextSavedFn* context_saved_fn = nullptr;
- const ChromeMLCompletionFn* completion_fn = nullptr;
+ const char* prompt;
+ int context_mode;
+ uint32_t max_tokens;
+ uint32_t token_offset;
+ uint32_t max_output_tokens;
+ uint32_t score_ts_interval;
+ const ChromeMLOutputFn* output_fn;
+ const ChromeMLScoreTSFn* score_ts_fn;
+ const ChromeMLContextSavedFn* context_saved_fn;
+ const ChromeMLCompletionFn* completion_fn;
};
// Performance data filled out by GetEstimatedPerformance().
diff --git a/services/on_device_model/on_device_model_service_unittest.cc b/services/on_device_model/on_device_model_service_unittest.cc
index 66235b7..d63d2e2 100644
--- a/services/on_device_model/on_device_model_service_unittest.cc
+++ b/services/on_device_model/on_device_model_service_unittest.cc
@@ -7,6 +7,7 @@
#include "base/test/bind.h"
#include "base/test/task_environment.h"
#include "services/on_device_model/public/cpp/model_assets.h"
+#include "services/on_device_model/public/cpp/test_support/test_response_holder.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -15,28 +16,6 @@
using ::testing::ElementsAre;
-class ResponseHolder : public mojom::StreamingResponder {
- public:
- mojo::PendingRemote<mojom::StreamingResponder> BindRemote() {
- return receiver_.BindNewPipeAndPassRemote();
- }
-
- void OnResponse(const std::string& text) override {
- responses_.push_back(text);
- }
-
- void OnComplete(mojom::ResponseStatus status) override { run_loop_.Quit(); }
-
- void WaitForCompletion() { run_loop_.Run(); }
-
- const std::vector<std::string> responses() const { return responses_; }
-
- private:
- base::RunLoop run_loop_;
- mojo::Receiver<mojom::StreamingResponder> receiver_{this};
- std::vector<std::string> responses_;
-};
-
class ContextClientWaiter : public mojom::ContextClient {
public:
mojo::PendingRemote<mojom::ContextClient> BindRemote() {
@@ -94,7 +73,7 @@
TEST_F(OnDeviceModelServiceTest, Responds) {
auto model = LoadModel();
{
- ResponseHolder response;
+ TestResponseHolder response;
mojo::Remote<mojom::Session> session;
model->StartSession(session.BindNewPipeAndPassReceiver());
session->Execute(MakeInput("bar"), response.BindRemote());
@@ -105,7 +84,7 @@
}
// Try another input on the same model.
{
- ResponseHolder response;
+ TestResponseHolder response;
mojo::Remote<mojom::Session> session;
model->StartSession(session.BindNewPipeAndPassReceiver());
session->Execute(MakeInput("cat"), response.BindRemote());
@@ -119,7 +98,7 @@
TEST_F(OnDeviceModelServiceTest, AddContext) {
auto model = LoadModel();
- ResponseHolder response;
+ TestResponseHolder response;
mojo::Remote<mojom::Session> session;
model->StartSession(session.BindNewPipeAndPassReceiver());
session->AddContext(MakeInput("cheese"), {});
@@ -135,7 +114,7 @@
TEST_F(OnDeviceModelServiceTest, IgnoresContext) {
auto model = LoadModel();
- ResponseHolder response;
+ TestResponseHolder response;
mojo::Remote<mojom::Session> session;
model->StartSession(session.BindNewPipeAndPassReceiver());
session->AddContext(MakeInput("cheese"), {});
@@ -151,7 +130,7 @@
TEST_F(OnDeviceModelServiceTest, AddContextWithTokenLimits) {
auto model = LoadModel();
- ResponseHolder response;
+ TestResponseHolder response;
mojo::Remote<mojom::Session> session;
model->StartSession(session.BindNewPipeAndPassReceiver());
@@ -181,7 +160,7 @@
TEST_F(OnDeviceModelServiceTest, CancelsPreviousSession) {
auto model = LoadModel();
- ResponseHolder response1;
+ TestResponseHolder response1;
mojo::Remote<mojom::Session> session1;
model->StartSession(session1.BindNewPipeAndPassReceiver());
session1->Execute(MakeInput("1"), response1.BindRemote());
@@ -200,7 +179,7 @@
EXPECT_THAT(response1.responses(), ElementsAre("Input: 1\n"));
// Second session still works.
- ResponseHolder response2;
+ TestResponseHolder response2;
session2->Execute(MakeInput("2"), response2.BindRemote());
response2.WaitForCompletion();
EXPECT_THAT(response2.responses(), ElementsAre("Input: 2\n"));
diff --git a/services/on_device_model/public/cpp/test_support/BUILD.gn b/services/on_device_model/public/cpp/test_support/BUILD.gn
new file mode 100644
index 0000000..746b1cb
--- /dev/null
+++ b/services/on_device_model/public/cpp/test_support/BUILD.gn
@@ -0,0 +1,17 @@
+# Copyright 2024 The Chromium Authors
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+source_set("test_support") {
+ testonly = true
+ sources = [
+ "test_response_holder.cc",
+ "test_response_holder.h",
+ ]
+
+ public_deps = [
+ "//base",
+ "//mojo/public/cpp/bindings",
+ "//services/on_device_model/public/mojom",
+ ]
+}
diff --git a/services/on_device_model/public/cpp/test_support/test_response_holder.cc b/services/on_device_model/public/cpp/test_support/test_response_holder.cc
new file mode 100644
index 0000000..8725a30
--- /dev/null
+++ b/services/on_device_model/public/cpp/test_support/test_response_holder.cc
@@ -0,0 +1,30 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "services/on_device_model/public/cpp/test_support/test_response_holder.h"
+
+namespace on_device_model {
+
+TestResponseHolder::TestResponseHolder() = default;
+
+TestResponseHolder::~TestResponseHolder() = default;
+
+mojo::PendingRemote<mojom::StreamingResponder>
+TestResponseHolder::BindRemote() {
+ return receiver_.BindNewPipeAndPassRemote();  // Remote end of a new pipe.
+}
+
+void TestResponseHolder::WaitForCompletion() {
+ run_loop_.Run();  // Returns after OnComplete() quits the loop.
+}
+
+void TestResponseHolder::OnResponse(const std::string& text) {
+ responses_.push_back(text);  // Appended in the order received.
+}
+
+void TestResponseHolder::OnComplete(mojom::ResponseStatus status) {
+ run_loop_.Quit();  // Unblocks WaitForCompletion(); `status` is not stored.
+}
+
+} // namespace on_device_model
diff --git a/services/on_device_model/public/cpp/test_support/test_response_holder.h b/services/on_device_model/public/cpp/test_support/test_response_holder.h
new file mode 100644
index 0000000..2e7d0b1
--- /dev/null
+++ b/services/on_device_model/public/cpp/test_support/test_response_holder.h
@@ -0,0 +1,48 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SERVICES_ON_DEVICE_MODEL_PUBLIC_CPP_TEST_SUPPORT_TEST_RESPONSE_HOLDER_H_
+#define SERVICES_ON_DEVICE_MODEL_PUBLIC_CPP_TEST_SUPPORT_TEST_RESPONSE_HOLDER_H_
+
+#include <string>
+#include <vector>
+
+#include "base/run_loop.h"
+#include "mojo/public/cpp/bindings/pending_remote.h"
+#include "mojo/public/cpp/bindings/receiver.h"
+#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
+
+namespace on_device_model {
+
+// Helper to accumulate a streamed response from model execution. Used by the
+// service unit tests and by downstream clients; defined upstream to avoid
+// downstream mojom dependencies.
+class TestResponseHolder : public mojom::StreamingResponder {
+ public:
+ TestResponseHolder();
+ ~TestResponseHolder() override;
+
+ // Returns a remote which can be used to stream a response to this object.
+ mojo::PendingRemote<mojom::StreamingResponder> BindRemote();
+
+ // Accumulated responses so far from whoever controls the remote
+ // StreamingResponder endpoint. The reference is valid while `this` is alive.
+ const std::vector<std::string>& responses() const { return responses_; }
+
+ // Spins a RunLoop until this object observes completion of its response.
+ void WaitForCompletion();
+
+ // mojom::StreamingResponder:
+ void OnResponse(const std::string& text) override;
+ void OnComplete(mojom::ResponseStatus status) override;
+
+ private:
+ base::RunLoop run_loop_;
+ mojo::Receiver<mojom::StreamingResponder> receiver_{this};
+ std::vector<std::string> responses_;
+};
+
+} // namespace on_device_model
+
+#endif // SERVICES_ON_DEVICE_MODEL_PUBLIC_CPP_TEST_SUPPORT_TEST_RESPONSE_HOLDER_H_