on_device_model: Add T&S support for ChromeOS
To support the T&S layer in the ChromeosPlatformModelLoader, we need to
read the ts_data & ts_sp_model from the model.json file.
BUG=b:338325476
TEST=None
Change-Id: I1dc4751c9b83709fd287190cbe51083f713ff3c4
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5506560
Reviewed-by: Clark DuVall <cduvall@chromium.org>
Commit-Queue: Yi Chou <yich@google.com>
Cr-Commit-Position: refs/heads/main@{#1295481}
diff --git a/services/on_device_model/platform_model_loader_chromeos.cc b/services/on_device_model/platform_model_loader_chromeos.cc
index 716a2b1..7f91d692 100644
--- a/services/on_device_model/platform_model_loader_chromeos.cc
+++ b/services/on_device_model/platform_model_loader_chromeos.cc
@@ -70,6 +70,9 @@
constexpr char kModelPathKey[] = "model_path";
constexpr char kWeightPathKey[] = "weight_path";
constexpr char kSpModelPathKey[] = "sp_model_path";
+constexpr char kTsDataPathKey[] = "ts_data_path";
+constexpr char kTsSpModelPathKey[] = "ts_sp_model_path";
+constexpr char kTsDimensionKey[] = "ts_dimension";
constexpr char kVersionKey[] = "version";
constexpr int kDefaultMaxTokens = 1024;
constexpr char kLoadStatusHistogramName[] =
@@ -318,16 +321,29 @@
}
}
+ const std::string* ts_data = model_dict->FindString(kTsDataPathKey);
+ const std::string* ts_sp_model = model_dict->FindString(kTsSpModelPathKey);
+ std::optional<int> ts_dimension = model_dict->FindInt(kTsDimensionKey);
+
on_device_model::ModelAssetPaths model_paths;
model_paths.sp_model = dlc_root.Append(*sp_model);
model_paths.model = dlc_root.Append(*model_path);
model_paths.weights = dlc_root.Append(*weight_path);
+ if (ts_data) {
+ model_paths.ts_data = dlc_root.Append(*ts_data);
+ }
+ if (ts_sp_model) {
+ model_paths.ts_sp_model = dlc_root.Append(*ts_sp_model);
+ }
auto params = on_device_model::mojom::LoadModelParams::New();
params->assets = on_device_model::LoadModelAssets(model_paths);
params->max_tokens = max_tokens.value_or(kDefaultMaxTokens);
params->adaptation_ranks = adaptation_ranks;
params->support_multiple_sessions = true;
+ if (ts_dimension.has_value()) {
+ params->ts_dimension = *ts_dimension;
+ }
auto platform_model = base::MakeRefCounted<PlatformModel>();
service_->LoadModel(