diff --git a/DEPS b/DEPS index 19298dc4..77c30fa 100644 --- a/DEPS +++ b/DEPS
@@ -222,7 +222,7 @@ # luci-go CIPD package version. # Make sure the revision is uploaded by infra-packagers builder. # https://ci.chromium.org/p/infra-internal/g/infra-packagers/console - 'luci_go': 'git_revision:a80cc32b17397d30b5fdb121520551cfba7089a2', + 'luci_go': 'git_revision:7d578d09794d360f422427b0158e9515835f7ff3', # This can be overridden, e.g. with custom_vars, to build clang from HEAD # instead of downloading the prebuilt pinned revision. @@ -271,7 +271,7 @@ # Three lines of non-changing comments so that # the commit queue can handle CLs rolling Skia # and whatever else without interference from each other. - 'skia_revision': '052566d8ccb72f3a6c2633d4d063ab429db8d315', + 'skia_revision': 'c9c86fef729d4b48d88d2dd7d9de3768de82a1a2', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling V8 # and whatever else without interference from each other. @@ -342,7 +342,7 @@ # Three lines of non-changing comments so that # the commit queue can handle CLs rolling catapult # and whatever else without interference from each other. - 'catapult_revision': '260078ccc5d76307f130414e545624b836a2ef17', + 'catapult_revision': '5484928d765f1bd6bdf55c504ded13ee5fb8eabb', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling libFuzzer # and whatever else without interference from each other. @@ -350,7 +350,7 @@ # Three lines of non-changing comments so that # the commit queue can handle CLs rolling devtools-frontend # and whatever else without interference from each other. - 'devtools_frontend_revision': 'be65734ac7f4d97f56bce0d7f89c9ed4f92ec816', + 'devtools_frontend_revision': '476c43b6f43ead25961f61343e81939b0b15dcdd', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling libprotobuf-mutator # and whatever else without interference from each other. 
@@ -386,7 +386,7 @@ # Three lines of non-changing comments so that # the commit queue can handle CLs rolling feed # and whatever else without interference from each other. - 'dawn_revision': 'c2eccfc887def447e2f1833408674095ad7d0443', + 'dawn_revision': 'e34e059804709726e9cbd35547c3ff857924af33', # Three lines of non-changing comments so that # the commit queue can handle CLs rolling feed # and whatever else without interference from each other. @@ -767,7 +767,7 @@ }, 'src/ios/third_party/material_components_ios/src': { - 'url': Var('chromium_git') + '/external/github.com/material-components/material-components-ios.git' + '@' + '99d2e285427146c61bdb4293ffde246ee2c38117', + 'url': Var('chromium_git') + '/external/github.com/material-components/material-components-ios.git' + '@' + '3a935e0d9630a8867202958d9097acde16e38d0d', 'condition': 'checkout_ios', }, @@ -916,7 +916,7 @@ 'packages': [ { 'package': 'chromium/third_party/androidx', - 'version': 'HkgOVPvf4SUpyd0B5842wpzJlSkZqdcPuX6M2QOFemsC', + 'version': 'fgr2Q9m0sFBMGJd5Hv_GZ9US1xkOqnt33aCouiyvx80C', }, ], 'condition': 'checkout_android', @@ -1132,7 +1132,7 @@ }, 'src/third_party/depot_tools': - Var('chromium_git') + '/chromium/tools/depot_tools.git' + '@' + '8d2d507a4974d20d4cac8a8486e31a27aca8b562', + Var('chromium_git') + '/chromium/tools/depot_tools.git' + '@' + '31140af3cfe21203cd9b48135859654db4d39ca1', 'src/third_party/devtools-frontend/src': Var('chromium_git') + '/devtools/devtools-frontend' + '@' + Var('devtools_frontend_revision'), @@ -1512,7 +1512,7 @@ Var('chromium_git') + '/external/github.com/cisco/openh264' + '@' + 'fac04ceb3e966f613ed17e98178e9d690280bba6', 'src/third_party/openscreen/src': - Var('chromium_git') + '/openscreen' + '@' + 'd7b6a03bc015d1580d00cd978a1b73d30d1cc6fb', + Var('chromium_git') + '/openscreen' + '@' + 'a82e9e4df560dc770e358b907128d378e35541ca', 'src/third_party/openxr/src': { 'url': Var('chromium_git') + '/external/github.com/KhronosGroup/OpenXR-SDK' + '@' + 
'bf21ccb1007bb531b45d9978919a56ea5059c245', @@ -1529,7 +1529,7 @@ }, 'src/third_party/perfetto': - Var('android_git') + '/platform/external/perfetto.git' + '@' + 'd1cb81f2aa43df0d60a387e49c7ca570b685ca7f', + Var('android_git') + '/platform/external/perfetto.git' + '@' + '5e0d3dbcc00516ba502fc5f9631cfd2136664489', 'src/third_party/perl': { 'url': Var('chromium_git') + '/chromium/deps/perl.git' + '@' + '6f3e5028eb65d0b4c5fdd792106ac4c84eee1eb3', @@ -1665,7 +1665,7 @@ 'condition': 'checkout_android', }, - 'src/third_party/vulkan-deps': '{chromium_git}/vulkan-deps@13813e2f06a1eaecf12f8a095aae27c661b32613', + 'src/third_party/vulkan-deps': '{chromium_git}/vulkan-deps@648e3cd0461952f3a03aa877c34bca2c48f5f434', 'src/third_party/vulkan_memory_allocator': Var('chromium_git') + '/external/github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git' + '@' + 'ebe84bec02c041d28f902da0214bf442743fc907', @@ -1704,7 +1704,7 @@ Var('chromium_git') + '/external/github.com/gpuweb/cts.git' + '@' + '5f05d6d5e625fe6f04903335473c5638ddf94514', 'src/third_party/webrtc': - Var('webrtc_git') + '/src.git' + '@' + '7fd0cb266cc2d2a883ee78461aa6e656a59166f8', + Var('webrtc_git') + '/src.git' + '@' + '65e46b93b501ed5f0c3c608652aebc31cdd1e7c6', 'src/third_party/libgifcodec': Var('skia_git') + '/libgifcodec' + '@'+ Var('libgifcodec_revision'), @@ -1734,7 +1734,7 @@ 'packages': [ { 'package': 'skia/tools/goldctl/linux-amd64', - 'version': '5TkbcvJVzHp9wQ53-1Nm90VM5uLJaVVpC_AwYtPXxGIC', + 'version': 'GsWjhyPQqwbT6QIxNBpT4HZ0vZ9GX1TJFj4xhn8vXJEC', }, ], 'dep_type': 'cipd', @@ -1777,7 +1777,7 @@ Var('chromium_git') + '/v8/v8.git' + '@' + Var('v8_revision'), 'src-internal': { - 'url': 'https://chrome-internal.googlesource.com/chrome/src-internal.git@a5d7bee73f037574cb851745f53abbf51febc267', + 'url': 'https://chrome-internal.googlesource.com/chrome/src-internal.git@11888b5b3904793cb17c407be9ae9c07cc8d034e', 'condition': 'checkout_src_internal', },
diff --git a/android_webview/browser/network_service/aw_proxying_restricted_cookie_manager.cc b/android_webview/browser/network_service/aw_proxying_restricted_cookie_manager.cc index 6943be6..2db1a63 100644 --- a/android_webview/browser/network_service/aw_proxying_restricted_cookie_manager.cc +++ b/android_webview/browser/network_service/aw_proxying_restricted_cookie_manager.cc
@@ -213,4 +213,10 @@ } } +void AwProxyingRestrictedCookieManager:: + ConvertPartitionedCookiesToUnpartitioned(const GURL& url) { + underlying_restricted_cookie_manager_ + ->ConvertPartitionedCookiesToUnpartitioned(url); +} + } // namespace android_webview
diff --git a/android_webview/browser/network_service/aw_proxying_restricted_cookie_manager.h b/android_webview/browser/network_service/aw_proxying_restricted_cookie_manager.h index 83a9479e..5ccede3d 100644 --- a/android_webview/browser/network_service/aw_proxying_restricted_cookie_manager.h +++ b/android_webview/browser/network_service/aw_proxying_restricted_cookie_manager.h
@@ -83,6 +83,9 @@ bool AllowCookies(const GURL& url, const net::SiteForCookies& site_for_cookies) const; + // TODO(https://crbug.com/1296161): Delete this function. + void ConvertPartitionedCookiesToUnpartitioned(const GURL& url) override; + private: AwProxyingRestrictedCookieManager( mojo::PendingRemote<network::mojom::RestrictedCookieManager>
diff --git a/ash/app_list/app_list_controller_impl_unittest.cc b/ash/app_list/app_list_controller_impl_unittest.cc index 09453e4..23ef483 100644 --- a/ash/app_list/app_list_controller_impl_unittest.cc +++ b/ash/app_list/app_list_controller_impl_unittest.cc
@@ -1311,6 +1311,9 @@ } TEST_F(AppListControllerImplAppListBubbleTest, HideContinueSectionUpdatesPref) { + base::test::ScopedFeatureList feature_list( + features::kLauncherHideContinueSection); + auto* controller = Shell::Get()->app_list_controller(); PrefService* prefs = Shell::Get()->session_controller()->GetLastActiveUserPrefService();
diff --git a/ash/app_list/app_list_metrics_unittest.cc b/ash/app_list/app_list_metrics_unittest.cc index 0ff07b8..ac28500 100644 --- a/ash/app_list/app_list_metrics_unittest.cc +++ b/ash/app_list/app_list_metrics_unittest.cc
@@ -575,6 +575,8 @@ TEST_F(AppListMetricsProductivityLauncherTest, HideContinueSectionMetricInClamshellMode) { + base::test::ScopedFeatureList feature_list( + features::kLauncherHideContinueSection); base::HistogramTester histograms; // Show the app list with a full continue section. @@ -604,6 +606,8 @@ TEST_F(AppListMetricsProductivityLauncherTest, HideContinueSectionMetricInTabletMode) { + base::test::ScopedFeatureList feature_list( + features::kLauncherHideContinueSection); base::HistogramTester histograms; // Show the tablet mode app list with a full continue section.
diff --git a/ash/app_list/app_list_util.cc b/ash/app_list/app_list_util.cc index cedd1cc..3fe0a51 100644 --- a/ash/app_list/app_list_util.cc +++ b/ash/app_list/app_list_util.cc
@@ -9,6 +9,7 @@ #include "ash/constants/ash_constants.h" #include "ash/constants/ash_features.h" #include "ash/public/cpp/app_list/app_list_color_provider.h" +#include "ash/style/ash_color_provider.h" #include "ui/events/event.h" #include "ui/gfx/canvas.h" #include "ui/gfx/geometry/rect.h" @@ -126,14 +127,13 @@ return true; } -gfx::ImageSkia CreateIconWithCircleBackground(const gfx::ImageSkia& icon, - SkColor background_color) { +gfx::ImageSkia CreateIconWithCircleBackground(const gfx::ImageSkia& icon) { DCHECK_EQ(icon.width(), icon.height()); - // TODO(crbug.com/1185943): We should not be passing in hardcoded - // `background_color`s here. Callers should be updated to use the appropriate - // color from the NativeTheme or AshColorProvider. return gfx::ImageSkiaOperations::CreateImageWithCircleBackground( - icon.width() / 2, background_color, icon); + icon.width() / 2, + AshColorProvider::Get()->GetBaseLayerColor( + AshColorProvider::BaseLayerType::kOpaque), + icon); } void PaintFocusBar(gfx::Canvas* canvas,
diff --git a/ash/app_list/app_list_util.h b/ash/app_list/app_list_util.h index 2994a3a7..c0f46a2 100644 --- a/ash/app_list/app_list_util.h +++ b/ash/app_list/app_list_util.h
@@ -64,8 +64,7 @@ - // Returns a new image with the `icon` atop a circle background with - // `background_color`. + // Returns a new image with the `icon` atop a circle background filled with + // the opaque base layer color from AshColorProvider. ASH_EXPORT gfx::ImageSkia CreateIconWithCircleBackground( - const gfx::ImageSkia& icon, - SkColor background_color); + const gfx::ImageSkia& icon); // Paints a rounded focus bar on `canvas` starting at `content_origin` extending // `height` dips vertically.
diff --git a/ash/app_list/views/app_list_bubble_apps_page_unittest.cc b/ash/app_list/views/app_list_bubble_apps_page_unittest.cc index cf14aa2..fa06601 100644 --- a/ash/app_list/views/app_list_bubble_apps_page_unittest.cc +++ b/ash/app_list/views/app_list_bubble_apps_page_unittest.cc
@@ -261,6 +261,9 @@ } TEST_F(AppListBubbleAppsPageTest, CanHideContinueSection) { + base::test::ScopedFeatureList feature_list( + features::kLauncherHideContinueSection); + // Show the app list with enough items to make the continue section and // recent apps visible. auto* helper = GetAppListTestHelper(); @@ -285,6 +288,9 @@ } TEST_F(AppListBubbleAppsPageTest, CanShowContinueSectionByClickingButton) { + base::test::ScopedFeatureList feature_list( + features::kLauncherHideContinueSection); + // Simulate a user with the continue section hidden on startup. Shell::Get()->app_list_controller()->SetHideContinueSection(true);
diff --git a/ash/app_list/views/apps_container_view_unittest.cc b/ash/app_list/views/apps_container_view_unittest.cc index 197af7c..743c3c9f 100644 --- a/ash/app_list/views/apps_container_view_unittest.cc +++ b/ash/app_list/views/apps_container_view_unittest.cc
@@ -18,7 +18,10 @@ class AppsContainerViewTest : public AshTestBase { public: AppsContainerViewTest() { - features_.InitAndEnableFeature(features::kProductivityLauncher); + // These tests primarily exercise the "hide continue section" behavior. + features_.InitWithFeatures({features::kProductivityLauncher, + features::kLauncherHideContinueSection}, + {}); } ~AppsContainerViewTest() override = default;
diff --git a/ash/app_list/views/continue_section_view_unittest.cc b/ash/app_list/views/continue_section_view_unittest.cc index d6e757a9..5bc19f3 100644 --- a/ash/app_list/views/continue_section_view_unittest.cc +++ b/ash/app_list/views/continue_section_view_unittest.cc
@@ -1492,6 +1492,9 @@ // when this feature works in tablet mode. TEST_F(ContinueSectionViewClamshellModeTest, HidingContinueSectionHidesPrivacyNotice) { + base::test::ScopedFeatureList feature_list( + features::kLauncherHideContinueSection); + AddSearchResult("id1", AppListSearchResultType::kZeroStateFile); AddSearchResult("id2", AppListSearchResultType::kZeroStateDrive); AddSearchResult("id3", AppListSearchResultType::kZeroStateDrive);
diff --git a/ash/app_list/views/search_result_tile_item_view.cc b/ash/app_list/views/search_result_tile_item_view.cc index b50a1be..7aa6f46 100644 --- a/ash/app_list/views/search_result_tile_item_view.cc +++ b/ash/app_list/views/search_result_tile_item_view.cc
@@ -397,8 +397,7 @@ gfx::ImageSkia badge_icon_skia = badge_icon.Rasterize(GetColorProvider()); if (use_badge_icon_background) { - badge_icon_skia = - CreateIconWithCircleBackground(badge_icon_skia, SK_ColorWHITE); + badge_icon_skia = CreateIconWithCircleBackground(badge_icon_skia); } gfx::ImageSkia resized_badge_icon(
diff --git a/ash/app_list/views/search_result_view.cc b/ash/app_list/views/search_result_view.cc index 2d8f096..dc41742 100644 --- a/ash/app_list/views/search_result_view.cc +++ b/ash/app_list/views/search_result_view.cc
@@ -722,8 +722,7 @@ result()->badge_icon().Rasterize(GetColorProvider()); if (result()->use_badge_icon_background()) { - badge_icon_skia = - CreateIconWithCircleBackground(badge_icon_skia, SK_ColorWHITE); + badge_icon_skia = CreateIconWithCircleBackground(badge_icon_skia); } gfx::ImageSkia resized_badge_icon(
diff --git a/ash/ash_strings.grd b/ash/ash_strings.grd index 176fff7..e4e14ee 100644 --- a/ash/ash_strings.grd +++ b/ash/ash_strings.grd
@@ -1907,6 +1907,15 @@ <message name="IDS_ASH_DESKS_TEMPLATES_LIBRARY_NO_DESKS_LABEL" desc="The text that shows in the saved desk library when there are no saved desks."> No saved desks </message> + <message name="IDS_ASH_DESKS_TEMPLATES_LIBRARY_TEMPLATES_GRID_ITEM_ACCESSIBLE_NAME" desc="The full accessible name of template grid item"> + Template, <ph name="TEMPLATE_NAME">$1</ph> + </message> + <message name="IDS_ASH_DESKS_TEMPLATES_LIBRARY_SAVE_AND_RECALL_GRID_ITEM_ACCESSIBLE_NAME" desc="The full accessible name of save and recall grid item"> + Saved desk, <ph name="SAVE_AND_RECALL_DESK_NAME">$1</ph> + </message> + <message name="IDS_ASH_DESKS_TEMPLATES_LIBRARY_SAVED_DESK_GRID_ITEM_EXTRA_ACCESSIBLE_DESCRIPTION" desc="The extra accessible description of saved desk grid item"> + Press Ctrl+W to close + </message> <!-- Virtual Desks --> <message name="IDS_ASH_DESKS_NEW_DESK_BUTTON" desc="The label of the new virtual desk (a.k.a. workspaces) button.">
diff --git a/ash/ash_strings_grd/IDS_ASH_DESKS_TEMPLATES_LIBRARY_SAVED_DESK_GRID_ITEM_EXTRA_ACCESSIBLE_DESCRIPTION.png.sha1 b/ash/ash_strings_grd/IDS_ASH_DESKS_TEMPLATES_LIBRARY_SAVED_DESK_GRID_ITEM_EXTRA_ACCESSIBLE_DESCRIPTION.png.sha1 new file mode 100644 index 0000000..644fd0d --- /dev/null +++ b/ash/ash_strings_grd/IDS_ASH_DESKS_TEMPLATES_LIBRARY_SAVED_DESK_GRID_ITEM_EXTRA_ACCESSIBLE_DESCRIPTION.png.sha1
@@ -0,0 +1 @@ +221ae7039a68577a32f0ed4c733d142de65cafa8 \ No newline at end of file
diff --git a/ash/ash_strings_grd/IDS_ASH_DESKS_TEMPLATES_LIBRARY_SAVE_AND_RECALL_GRID_ITEM_ACCESSIBLE_NAME.png.sha1 b/ash/ash_strings_grd/IDS_ASH_DESKS_TEMPLATES_LIBRARY_SAVE_AND_RECALL_GRID_ITEM_ACCESSIBLE_NAME.png.sha1 new file mode 100644 index 0000000..9d2b73f --- /dev/null +++ b/ash/ash_strings_grd/IDS_ASH_DESKS_TEMPLATES_LIBRARY_SAVE_AND_RECALL_GRID_ITEM_ACCESSIBLE_NAME.png.sha1
@@ -0,0 +1 @@ +ae9c08dd236659dd6e4a867a73642c614598fa13 \ No newline at end of file
diff --git a/ash/ash_strings_grd/IDS_ASH_DESKS_TEMPLATES_LIBRARY_TEMPLATES_GRID_ITEM_ACCESSIBLE_NAME.png.sha1 b/ash/ash_strings_grd/IDS_ASH_DESKS_TEMPLATES_LIBRARY_TEMPLATES_GRID_ITEM_ACCESSIBLE_NAME.png.sha1 new file mode 100644 index 0000000..aa6b9cd --- /dev/null +++ b/ash/ash_strings_grd/IDS_ASH_DESKS_TEMPLATES_LIBRARY_TEMPLATES_GRID_ITEM_ACCESSIBLE_NAME.png.sha1
@@ -0,0 +1 @@ +44bd8f723cd85e4e7e4451a51cd80fbbc36579b9 \ No newline at end of file
diff --git a/ash/components/arc/ime/arc_ime_service.cc b/ash/components/arc/ime/arc_ime_service.cc index 50328f2..bab7a1e2 100644 --- a/ash/components/arc/ime/arc_ime_service.cc +++ b/ash/components/arc/ime/arc_ime_service.cc
@@ -646,8 +646,8 @@ return false; } -absl::optional<ui::GrammarFragment> -ArcImeService::GetGrammarFragmentAtCursor() { +absl::optional<ui::GrammarFragment> ArcImeService::GetGrammarFragmentAtCursor() + const { // TODO(https://crbug.com/1201454): Implement this method. NOTIMPLEMENTED_LOG_ONCE(); return absl::nullopt;
diff --git a/ash/components/arc/ime/arc_ime_service.h b/ash/components/arc/ime/arc_ime_service.h index 733c8a4..f69d152 100644 --- a/ash/components/arc/ime/arc_ime_service.h +++ b/ash/components/arc/ime/arc_ime_service.h
@@ -154,7 +154,8 @@ gfx::Range GetAutocorrectRange() const override; gfx::Rect GetAutocorrectCharacterBounds() const override; bool SetAutocorrectRange(const gfx::Range& range) override; - absl::optional<ui::GrammarFragment> GetGrammarFragmentAtCursor() override; + absl::optional<ui::GrammarFragment> GetGrammarFragmentAtCursor() + const override; bool ClearGrammarFragments(const gfx::Range& range) override; bool AddGrammarFragments( const std::vector<ui::GrammarFragment>& fragments) override;
diff --git a/ash/components/arc/video_accelerator/OWNERS b/ash/components/arc/video_accelerator/OWNERS index 8fb3221c8..8a7d101 100644 --- a/ash/components/arc/video_accelerator/OWNERS +++ b/ash/components/arc/video_accelerator/OWNERS
@@ -1,5 +1,4 @@ akahuang@chromium.org acourbot@chromium.org -dstaessens@chromium.org hiroh@chromium.org posciak@chromium.org
diff --git a/ash/constants/ash_features.cc b/ash/constants/ash_features.cc index 66e1b55..58ba7cf 100644 --- a/ash/constants/ash_features.cc +++ b/ash/constants/ash_features.cc
@@ -29,6 +29,12 @@ const base::Feature kAdaptiveCharging{"AdaptiveCharging", base::FEATURE_DISABLED_BY_DEFAULT}; +// Enable the logic to show the notifications for Adaptive Charging features. +// This is intended to be used by developers to test the UI aspect of the +// feature. +const base::Feature kAdaptiveChargingForTesting{ + "AdaptiveChargingForTesting", base::FEATURE_DISABLED_BY_DEFAULT}; + // Adjusts portrait mode split view to avoid the input field in the bottom // window being occluded by the virtual keyboard. const base::Feature kAdjustSplitViewForVK{"AdjustSplitViewForVK", @@ -952,7 +958,7 @@ // and recent apps context menu that allow the user to hide the continue // section. const base::Feature kLauncherHideContinueSection{ - "LauncherHideContinueSection", base::FEATURE_ENABLED_BY_DEFAULT}; + "LauncherHideContinueSection", base::FEATURE_DISABLED_BY_DEFAULT}; // Uses short intervals for launcher nudge for testing if enabled. const base::Feature kLauncherNudgeShortInterval{ @@ -1607,6 +1613,10 @@ return base::FeatureList::IsEnabled(kAdaptiveCharging); } +bool IsAdaptiveChargingForTestingEnabled() { + return base::FeatureList::IsEnabled(kAdaptiveChargingForTesting); +} + bool IsAdjustSplitViewForVKEnabled() { return base::FeatureList::IsEnabled(kAdjustSplitViewForVK); }
diff --git a/ash/constants/ash_features.h b/ash/constants/ash_features.h index 4afa3403..4550a491 100644 --- a/ash/constants/ash_features.h +++ b/ash/constants/ash_features.h
@@ -20,6 +20,8 @@ COMPONENT_EXPORT(ASH_CONSTANTS) extern const base::Feature kAdaptiveCharging; COMPONENT_EXPORT(ASH_CONSTANTS) +extern const base::Feature kAdaptiveChargingForTesting; +COMPONENT_EXPORT(ASH_CONSTANTS) extern const base::Feature kAdjustSplitViewForVK; COMPONENT_EXPORT(ASH_CONSTANTS) extern const base::Feature kAllowAmbientEQ; COMPONENT_EXPORT(ASH_CONSTANTS) @@ -617,6 +619,7 @@ COMPONENT_EXPORT(ASH_CONSTANTS) bool AreImprovedScreenCaptureSettingsEnabled(); COMPONENT_EXPORT(ASH_CONSTANTS) bool DoWindowsFollowCursor(); COMPONENT_EXPORT(ASH_CONSTANTS) bool IsAdaptiveChargingEnabled(); +COMPONENT_EXPORT(ASH_CONSTANTS) bool IsAdaptiveChargingForTestingEnabled(); COMPONENT_EXPORT(ASH_CONSTANTS) bool IsAdjustSplitViewForVKEnabled(); COMPONENT_EXPORT(ASH_CONSTANTS) bool IsAllowAmbientEQEnabled(); COMPONENT_EXPORT(ASH_CONSTANTS) bool IsAmbientModeAnimationEnabled();
diff --git a/ash/display/window_tree_host_manager_unittest.cc b/ash/display/window_tree_host_manager_unittest.cc index 532f1278..8ddf7c7 100644 --- a/ash/display/window_tree_host_manager_unittest.cc +++ b/ash/display/window_tree_host_manager_unittest.cc
@@ -1791,24 +1791,15 @@ Shell::Get()->window_tree_host_manager()->GetRootWindowForDisplayId( GetPrimaryDisplay().id()); - // Creating a window at the lower right end of the primary display with a - // view. - views::Widget* widget = TestWidgetBuilder().BuildOwnedByNativeWidget(); - views::View* view = - widget->GetContentsView()->AddChildView(std::make_unique<views::View>()); - - view->SetBounds(0, 0, 200, 200); - widget->Show(); - RootWindowTestEventHandler handler; root_window->AddPreTargetHandler(&handler); - ui::test::EventGenerator generator(root_window, widget->GetNativeWindow()); + ui::test::EventGenerator generator(root_window); - // Move the cursor to a coordinates that is in the logical bounds of the older - // display[0,0 800x600] and in the physical bounds[0,0 700x500] of new display - // but not in the logical bound[0,0 350x250] as the crash happens before full - // propagation of device_scale_factor effect. + // Move the cursor to a coordinate that is in the logical bounds of the older + // display[0,0 800x600] but not in the logical bounds[0,0 350x250] of the new + // display. The cursor coordinates also need to be in the root window's + // bounds[0,0 700x500] attached to the new primary display. generator.MoveMouseTo(400, 300); // Replace the primary display with a newer display with a different device @@ -1823,7 +1814,6 @@ display_manager()->OnNativeDisplaysChanged(display_info_list); root_window->RemovePreTargetHandler(&handler); - widget->CloseNow(); } TEST_F(WindowTreeHostManagerTest, KeyEventFromSecondaryDisplay) {
diff --git a/ash/login/ui/lock_screen_media_controls_view_unittest.cc b/ash/login/ui/lock_screen_media_controls_view_unittest.cc index 93b864e..eccd6cf9 100644 --- a/ash/login/ui/lock_screen_media_controls_view_unittest.cc +++ b/ash/login/ui/lock_screen_media_controls_view_unittest.cc
@@ -681,8 +681,12 @@ SimulateMediaSessionChanged( media_session::mojom::MediaPlaybackState::kPlaying); + const bool should_use_dark_color = + features::IsDarkLightModeEnabled() && + AshColorProvider::Get()->IsDarkModeEnabled(); gfx::ImageSkia default_icon = gfx::CreateVectorIcon( - message_center::kProductIcon, kAppIconSize, gfx::kGoogleGrey700); + message_center::kProductIcon, kAppIconSize, + should_use_dark_color ? gfx::kGoogleGrey500 : gfx::kGoogleGrey700); // Verify that the icon is initialized to the default. EXPECT_TRUE(icon_view()->GetImage().BackedBySameObjectAs(default_icon));
diff --git a/ash/system/accessibility/dictation_bubble_controller_unittest.cc b/ash/system/accessibility/dictation_bubble_controller_unittest.cc index 9e4e1b9..454bd23 100644 --- a/ash/system/accessibility/dictation_bubble_controller_unittest.cc +++ b/ash/system/accessibility/dictation_bubble_controller_unittest.cc
@@ -26,6 +26,8 @@ // AshTestBase: void SetUp() override { + scoped_feature_list_.InitAndDisableFeature( + chromeos::features::kDarkLightMode); AshTestBase::SetUp(); Shell::Get()->accessibility_controller()->dictation().SetEnabled(true); } @@ -92,6 +94,9 @@ std::vector<std::u16string> GetVisibleHints() { return GetView()->GetVisibleHintsForTesting(); } + + private: + base::test::ScopedFeatureList scoped_feature_list_; }; TEST_F(DictationBubbleControllerTest, ShowText) { @@ -155,8 +160,7 @@ // Verifies text and icon colors when the dark light mode feature is disabled. TEST_F(DictationBubbleControllerTest, NoDarkMode) { - base::test::ScopedFeatureList scoped_feature_list; - scoped_feature_list.InitAndDisableFeature(chromeos::features::kDarkLightMode); + ASSERT_FALSE(chromeos::features::IsDarkLightModeEnabled()); // Show bubble UI. EXPECT_FALSE(GetView());
diff --git a/ash/system/power/adaptive_charging_controller.cc b/ash/system/power/adaptive_charging_controller.cc index 4aba183..ba58ec55 100644 --- a/ash/system/power/adaptive_charging_controller.cc +++ b/ash/system/power/adaptive_charging_controller.cc
@@ -4,12 +4,22 @@ #include "ash/system/power/adaptive_charging_controller.h" +#include "ash/constants/ash_features.h" #include "base/scoped_observation.h" #include "chromeos/dbus/power/power_manager_client.h" #include "chromeos/dbus/power_manager/power_supply_properties.pb.h" namespace ash { +namespace { + +#if DCHECK_IS_ON() +// Fake input for notification testing. +constexpr int kFakeNotificationInputForTesting = 8; +#endif // DCHECK_IS_ON() + +} // namespace + AdaptiveChargingController::AdaptiveChargingController() : nudge_controller_(std::make_unique<AdaptiveChargingNudgeController>()), notification_controller_( @@ -29,6 +39,23 @@ void AdaptiveChargingController::PowerChanged( const power_manager::PowerSupplyProperties& proto) { +#if DCHECK_IS_ON() + if (features::IsAdaptiveChargingForTestingEnabled()) { + bool is_on_charger_now = false; + if (proto.has_external_power()) { + is_on_charger_now = + proto.external_power() == power_manager::PowerSupplyProperties::AC; + } + if (!is_on_charger_ && is_on_charger_now) { + nudge_controller_->ShowNudgeForTesting(); // IN-TEST + notification_controller_->ShowAdaptiveChargingNotification( + kFakeNotificationInputForTesting); + } + is_on_charger_ = is_on_charger_now; + return; + } +#endif // DCHECK_IS_ON() + // Return if this change does not contain any adaptive_delaying_charge info. if (!proto.has_adaptive_delaying_charge()) return;
diff --git a/ash/system/power/adaptive_charging_controller.h b/ash/system/power/adaptive_charging_controller.h index 21b84cee..5d2b415 100644 --- a/ash/system/power/adaptive_charging_controller.h +++ b/ash/system/power/adaptive_charging_controller.h
@@ -39,6 +39,7 @@ void PowerChanged(const power_manager::PowerSupplyProperties& proto) override; bool is_adaptive_delaying_charge_ = false; + bool is_on_charger_ = false; base::ScopedObservation<chromeos::PowerManagerClient, chromeos::PowerManagerClient::Observer>
diff --git a/ash/system/power/adaptive_charging_nudge_controller.cc b/ash/system/power/adaptive_charging_nudge_controller.cc index ae84379..e2c32dd 100644 --- a/ash/system/power/adaptive_charging_nudge_controller.cc +++ b/ash/system/power/adaptive_charging_nudge_controller.cc
@@ -52,6 +52,12 @@ weak_ptr_factory_.GetWeakPtr())); } +#if DCHECK_IS_ON() +void AdaptiveChargingNudgeController::ShowNudgeForTesting() { + SystemNudgeController::ShowNudge(); +} +#endif // DCHECK_IS_ON() + std::unique_ptr<SystemNudge> AdaptiveChargingNudgeController::CreateSystemNudge() { return std::make_unique<AdaptiveChargingNudge>();
diff --git a/ash/system/power/adaptive_charging_nudge_controller.h b/ash/system/power/adaptive_charging_nudge_controller.h index 01a306d..01bd107 100644 --- a/ash/system/power/adaptive_charging_nudge_controller.h +++ b/ash/system/power/adaptive_charging_nudge_controller.h
@@ -34,6 +34,12 @@ return nudge_delay_timer_.get(); } +#if DCHECK_IS_ON() + // This is intended to be used by developers to test the UI of the adaptive + // charging feature. + void ShowNudgeForTesting(); +#endif // DCHECK_IS_ON() + private: // SystemNudgeController: std::unique_ptr<SystemNudge> CreateSystemNudge() override;
diff --git a/ash/webui/personalization_app/mojom/BUILD.gn b/ash/webui/personalization_app/mojom/BUILD.gn index 1608227..a1994d8 100644 --- a/ash/webui/personalization_app/mojom/BUILD.gn +++ b/ash/webui/personalization_app/mojom/BUILD.gn
@@ -14,6 +14,7 @@ public_deps = [ "//mojo/public/mojom/base", + "//skia/public/mojom", "//url/mojom:url_mojom_gurl", ]
diff --git a/ash/webui/personalization_app/mojom/personalization_app.mojom b/ash/webui/personalization_app/mojom/personalization_app.mojom index aa9fa6d..b780a66 100644 --- a/ash/webui/personalization_app/mojom/personalization_app.mojom +++ b/ash/webui/personalization_app/mojom/personalization_app.mojom
@@ -7,6 +7,7 @@ import "mojo/public/mojom/base/big_buffer.mojom"; import "mojo/public/mojom/base/file_path.mojom"; import "mojo/public/mojom/base/string16.mojom"; +import "skia/public/mojom/skcolor.mojom"; import "url/mojom/url.mojom"; // This should be kept in sync with |ash::WallpaperLayout| @@ -560,6 +561,9 @@ interface KeyboardBacklightObserver { // Notifies the JS side about the current state of the backlight color. OnBacklightColorChanged(BacklightColor color); + + // Notifies the JS side about the current wallpaper-extracted color. + OnWallpaperColorChanged(skia.mojom.SkColor wallpaper_color); }; // Provides APIs to expose keyboard backlight settings.
diff --git a/ash/webui/personalization_app/personalization_app_ui.cc b/ash/webui/personalization_app/personalization_app_ui.cc index 06975d6..b41365d 100644 --- a/ash/webui/personalization_app/personalization_app_ui.cc +++ b/ash/webui/personalization_app/personalization_app_ui.cc
@@ -132,6 +132,10 @@ IDS_PERSONALIZATION_APP_ARIA_LABEL_CURRENT_AVATAR}, {"ariaAnnounceAvatarChanged", IDS_PERSONALIZATION_APP_ARIA_ANNOUNCE_AVATAR_CHANGED}, + {"ariaLabelCloseCamera", + IDS_PERSONALIZATION_APP_AVATAR_ARIA_LABEL_CLOSE_CAMERA}, + {"ariaLabelWebcamVideo", + IDS_PERSONALIZATION_APP_AVATAR_ARIA_LABEL_WEBCAM_VIDEO}, // Ambient mode related string. {"screensaverLabel", IDS_PERSONALIZATION_APP_SCREENSAVER_LABEL},
diff --git a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_actions.ts b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_actions.ts index 92d7c7a..ea79bcc 100644 --- a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_actions.ts +++ b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_actions.ts
@@ -3,24 +3,33 @@ // found in the LICENSE file. import {Action} from 'chrome://resources/js/cr/ui/store.js'; +import {SkColor} from 'chrome://resources/mojo/skia/public/mojom/skcolor.mojom-webui.js'; import {BacklightColor} from '../personalization_app.mojom-webui.js'; + /** * @fileoverview Defines the actions to update keyboard backlight settings. */ export enum KeyboardBacklightActionName { SET_BACKLIGHT_COLOR = 'set_backlight_color', + SET_WALLPAPER_COLOR = 'set_wallpaper_color', } -export type KeyboardBacklightActions = SetBacklightColorAction; +export type KeyboardBacklightActions = + SetBacklightColorAction|SetWallpaperColorAction; export type SetBacklightColorAction = Action&{ name: KeyboardBacklightActionName.SET_BACKLIGHT_COLOR, backlightColor: BacklightColor, }; +export type SetWallpaperColorAction = Action&{ + name: KeyboardBacklightActionName.SET_WALLPAPER_COLOR, + wallpaperColor: SkColor, +}; + /** * Sets the current value of the backlight color. */ @@ -31,3 +40,14 @@ backlightColor }; } + +/** + * Sets the current value of the wallpaper extracted color. + */ +export function setWallpaperColorAction(wallpaperColor: SkColor): + SetWallpaperColorAction { + return { + name: KeyboardBacklightActionName.SET_WALLPAPER_COLOR, + wallpaperColor + }; +}
diff --git a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_element.html b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_element.html index e470911..7a5ead34 100644 --- a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_element.html +++ b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_element.html
@@ -98,7 +98,7 @@ aria-label="$i18n{wallpaperColor}" aria-selected$="[[getWallpaperColorAriaSelected_(backlightColor_)]]"> <div class="color-inner-container" - style$="[[getColorInnerContainerStyle_(wallpaperColorId_, presetColors_)]]"> + style$="[[getWallpaperColorInnerContainerStyle_(wallpaperColor_)]]"> <paper-ripple class="circle"></paper-ripple> <iron-icon icon="personalization:auto"></iron-icon> </div>
diff --git a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_element.ts b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_element.ts index 70ba957..adc1b65e8 100644 --- a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_element.ts +++ b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_element.ts
@@ -9,6 +9,7 @@ import '../cros_button_style.js'; import {assert} from 'chrome://resources/js/assert_ts.js'; +import {SkColor} from 'chrome://resources/mojo/skia/public/mojom/skcolor.mojom-webui.js'; import {IronA11yKeysElement} from 'chrome://resources/polymer/v3_0/iron-a11y-keys/iron-a11y-keys.js'; import {IronSelectorElement} from 'chrome://resources/polymer/v3_0/iron-selector/iron-selector.js'; @@ -76,6 +77,9 @@ /** The selected backlight color in the system. */ backlightColor_: Object, + + /** The current wallpaper extracted color. */ + wallpaperColor_: Object, }; } @@ -85,6 +89,7 @@ private wallpaperColorId_: string; private ironSelectedColor_: HTMLElement; private backlightColor_: BacklightColor|null; + private wallpaperColor_: SkColor|null; override ready() { super.ready(); @@ -96,6 +101,8 @@ KeyboardBacklightObserver.initKeyboardBacklightObserverIfNeeded(); this.watch<KeyboardBacklight['backlightColor_']>( 'backlightColor_', state => state.keyboardBacklight.backlightColor); + this.watch<KeyboardBacklight['wallpaperColor_']>( + 'wallpaperColor_', state => state.keyboardBacklight.wallpaperColor); this.updateFromStore(); } @@ -223,8 +230,6 @@ const hexColors = Object.values(colors).map(color => color.hexVal).slice(1); return `background-image: linear-gradient(${hexColors})`; - case this.wallpaperColorId_: - return `background-color: #8AB4F8`; case 'whiteColor': // Add the border for the white background. return `background-color: ${ @@ -235,6 +240,19 @@ } } + private getWallpaperColorInnerContainerStyle_(wallpaperColor: SkColor): + string { + // Show the default style when wallpaper color is loading or invalid. + if (!wallpaperColor || !wallpaperColor.value) { + return `background-color: #FFFFFF; + border: 1px solid var(--cros-separator-color);`; + } + // Strip the alpha value and convert to hex string. 
+ const hexStr = + (wallpaperColor.value & 0xFFFFFF).toString(16).padStart(6, '0'); + return `background-color: #${hexStr};`; + } + private getPresetColorAriaLabel_(presetColorId: string): string { return this.i18n(presetColorId); }
diff --git a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_observer.ts b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_observer.ts index d425d850..7d67211 100644 --- a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_observer.ts +++ b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_observer.ts
@@ -2,10 +2,12 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +import {SkColor} from 'chrome://resources/mojo/skia/public/mojom/skcolor.mojom-webui.js'; + import {BacklightColor, KeyboardBacklightObserverInterface, KeyboardBacklightObserverReceiver, KeyboardBacklightProviderInterface} from '../personalization_app.mojom-webui.js'; import {PersonalizationStore} from '../personalization_store.js'; -import {setBacklightColorAction} from './keyboard_backlight_actions.js'; +import {setBacklightColorAction, setWallpaperColorAction} from './keyboard_backlight_actions.js'; import {getKeyboardBacklightProvider} from './keyboard_backlight_interface_provider.js'; /** @fileoverview listens for updates on keyboard backlight settings changes. */ @@ -47,4 +49,9 @@ const store = PersonalizationStore.getInstance(); store.dispatch(setBacklightColorAction(backlightColor)); } + + onWallpaperColorChanged(wallpaperColor: SkColor): void { + const store = PersonalizationStore.getInstance(); + store.dispatch(setWallpaperColorAction(wallpaperColor)); + } }
diff --git a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_reducers.ts b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_reducers.ts index 177bc48..bb5b702 100644 --- a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_reducers.ts +++ b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_reducers.ts
@@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +import {SkColor} from 'chrome://resources/mojo/skia/public/mojom/skcolor.mojom-webui.js'; + import {Actions} from '../personalization_actions.js'; import {BacklightColor} from '../personalization_app.mojom-webui.js'; import {ReducerFunction} from '../personalization_reducers.js'; @@ -21,9 +23,21 @@ } } +export function wallpaperColorReducer( + state: SkColor|null, action: Actions, _: PersonalizationState): SkColor| + null { + switch (action.name) { + case KeyboardBacklightActionName.SET_WALLPAPER_COLOR: + return action.wallpaperColor; + default: + return state; + } +} + export const keyboardBacklightReducers: { [K in keyof KeyboardBacklightState]: ReducerFunction<KeyboardBacklightState[K]> } = { backlightColor: backlightColorReducer, + wallpaperColor: wallpaperColorReducer, };
diff --git a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_state.ts b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_state.ts index 93b8c48..2d183455 100644 --- a/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_state.ts +++ b/ash/webui/personalization_app/resources/trusted/keyboard_backlight/keyboard_backlight_state.ts
@@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +import {SkColor} from 'chrome://resources/mojo/skia/public/mojom/skcolor.mojom-webui.js'; + import {BacklightColor} from '../personalization_app.mojom-webui.js'; /** @@ -9,10 +11,12 @@ */ export interface KeyboardBacklightState { backlightColor: BacklightColor|null; + wallpaperColor: SkColor|null; } export function emptyState(): KeyboardBacklightState { return { backlightColor: null, + wallpaperColor: null, }; }
diff --git a/ash/webui/personalization_app/resources/trusted/personalization_app.ts b/ash/webui/personalization_app/resources/trusted/personalization_app.ts index 7bfe71a6..90875d0 100644 --- a/ash/webui/personalization_app/resources/trusted/personalization_app.ts +++ b/ash/webui/personalization_app/resources/trusted/personalization_app.ts
@@ -65,7 +65,7 @@ export {TopicSourceList} from './ambient/topic_source_list_element.js'; export {AmbientZeroState} from './ambient/zero_state_element.js'; export {IFrameApi} from './iframe_api.js'; -export {KeyboardBacklightActionName, KeyboardBacklightActions, SetBacklightColorAction, setBacklightColorAction} from './keyboard_backlight/keyboard_backlight_actions.js'; +export {KeyboardBacklightActionName, KeyboardBacklightActions, SetBacklightColorAction, setBacklightColorAction, SetWallpaperColorAction, setWallpaperColorAction} from './keyboard_backlight/keyboard_backlight_actions.js'; export {KeyboardBacklight} from './keyboard_backlight/keyboard_backlight_element.js'; export {setKeyboardBacklightProviderForTesting} from './keyboard_backlight/keyboard_backlight_interface_provider.js'; export {KeyboardBacklightObserver} from './keyboard_backlight/keyboard_backlight_observer.js';
diff --git a/ash/webui/personalization_app/resources/trusted/user/avatar_camera_element.html b/ash/webui/personalization_app/resources/trusted/user/avatar_camera_element.html index 8313df1..317b1952 100644 --- a/ash/webui/personalization_app/resources/trusted/user/avatar_camera_element.html +++ b/ash/webui/personalization_app/resources/trusted/user/avatar_camera_element.html
@@ -95,7 +95,8 @@ width: 14px; } </style> -<cr-dialog id="dialog" show-close-button show-on-attach> +<cr-dialog id="dialog" show-close-button show-on-attach + close-text="$i18n{ariaLabelCloseCamera}"> <div slot="body"> <template is="dom-if" if="[[showLoading_(cameraStream_, previewBlobUrl_)]]"> <paper-spinner-lite id="cameraFeedSpinner" active></paper-spinner-lite> @@ -113,6 +114,7 @@ </svg> </template> <video id="webcamVideo" autoplay + aria-label="$i18n{ariaLabelWebcamVideo}" hidden$="[[!showCameraFeed_(cameraStream_, previewBlobUrl_)]]"></video> <template is="dom-if" if="[[previewBlobUrl_]]"> <img id="previewImg" src$="[[previewBlobUrl_]]">
diff --git a/ash/webui/personalization_app/resources/trusted/user/avatar_list_element.ts b/ash/webui/personalization_app/resources/trusted/user/avatar_list_element.ts index 033029c..0eb0ace4 100644 --- a/ash/webui/personalization_app/resources/trusted/user/avatar_list_element.ts +++ b/ash/webui/personalization_app/resources/trusted/user/avatar_list_element.ts
@@ -25,14 +25,37 @@ $: {avatarCamera: AvatarCamera}; } -type Option = { - id: string, - class: string, - imgSrc: string, - icon: string, - title: string, - defaultImageIndex?: number|null, -}; +enum OptionId { + LAST_EXTERNAL_IMAGE = 'lastExternalImage', + OPEN_CAMERA = 'openCamera', + OPEN_VIDEO = 'openVideo', + PROFILE_IMAGE = 'profileImage', + OPEN_FOLDER = 'openFolder', +} + +interface EnumeratedOption { + id: OptionId; + class: string; + imgSrc?: string; + icon: string; + title: string; +} + +interface DefaultOption { + id: string; + class: string; + imgSrc: string; + icon: string; + title: string; + defaultImageIndex: number; +} + +type Option = EnumeratedOption|DefaultOption; + +function isDefaultOption(option: Option): option is DefaultOption { + return option && + typeof (option as DefaultOption).defaultImageIndex === 'number'; +} export class AvatarList extends WithPersonalizationStore { static get is() { @@ -51,7 +74,10 @@ image_: Object, - lastExternalUserImageUrl_: Object, + lastExternalUserImageUrl_: { + type: Object, + observer: 'onLastExternalUserImageUrlChanged_', + }, /** The presence of a device camera. */ isCameraPresent_: { @@ -71,13 +97,17 @@ */ options_: { type: Array, - computed: - 'computeOptions_(isCameraPresent_, profileImage_, lastExternalUserImageUrl_, defaultUserImages_)', - observer: 'onOptionsChanged_', + value: [], }, }; } + static get observers() { + return [ + 'updateOptions_(isCameraPresent_, profileImage_, lastExternalUserImageUrl_, defaultUserImages_)', + ]; + } + private defaultUserImages_: Array<DefaultUserImage>|null; private profileImage_: Url|null; private isCameraPresent_: boolean; @@ -101,41 +131,39 @@ fetchDefaultUserImages(getUserProvider(), this.getStore()); } - /** Invoked to compute |options_|. */ - private computeOptions_( + /** Invoked to update |options_|. 
*/ + private updateOptions_( isCameraPresent: AvatarList['isCameraPresent_'], profileImage: AvatarList['profileImage_'], lastExternalUserImageUrl: AvatarList['lastExternalUserImageUrl_'], defaultUserImages: AvatarList['defaultUserImages_']) { - const options = [] as Option[]; + const options: Option[] = []; if (isCameraPresent) { // Add camera and video options. options.push({ - id: 'openCamera', + id: OptionId.OPEN_CAMERA, class: 'avatar-button-container', imgSrc: '', icon: 'personalization:camera', title: this.i18n('takeWebcamPhoto'), }); options.push({ - id: 'openVideo', + id: OptionId.OPEN_VIDEO, class: 'avatar-button-container', - imgSrc: '', icon: 'personalization:loop', title: this.i18n('takeWebcamVideo'), }); } // Add open folder option. options.push({ - id: 'openFolder', + id: OptionId.OPEN_FOLDER, class: 'avatar-button-container', - imgSrc: '', icon: 'personalization:folder', title: this.i18n('chooseAFile'), }); if (profileImage) { options.push({ - id: 'profileImage', + id: OptionId.PROFILE_IMAGE, class: 'image-container', imgSrc: profileImage.url, icon: 'personalization:checkmark', @@ -144,7 +172,7 @@ } if (lastExternalUserImageUrl) { options.push({ - id: 'lastExternalImage', + id: OptionId.LAST_EXTERNAL_IMAGE, class: 'image-container', imgSrc: lastExternalUserImageUrl.url, icon: 'personalization:checkmark', @@ -163,15 +191,29 @@ }); }); } - return options; - } - - /** Invoked on changes to |options_|. */ - private onOptionsChanged_(options: AvatarList['options_']) { this.updateList( /*propertyPath=*/ 'options_', - /*identityGetter=*/ (option: Option) => option.id, - /*newList=*/ options, /*identityBasedUpdate=*/ true); + /*identityGetter=*/ + (option: Option) => { + switch (option.id) { + // LAST_EXTERNAL_IMAGE needs to use imgSrc instead of id. Otherwise + // iron-list will not update properly when LAST_EXTERNAL_IMAGE + // changes, i.e. when user selects a new file from disk. 
+ case OptionId.LAST_EXTERNAL_IMAGE: + return option.imgSrc!; + default: + return option.id; + } + }, + /*newList=*/ options, + /*identityBasedUpdate=*/ true, + ); + } + + private onLastExternalUserImageUrlChanged_(_: Url|null, old: Url|null) { + if (old && old.url && old.url.startsWith('blob:')) { + URL.revokeObjectURL(old.url); + } } private onOptionSelected_(e: Event) { @@ -181,19 +223,19 @@ const divElement = e.currentTarget as HTMLDivElement; const id = divElement.id; switch (id) { - case 'openCamera': + case OptionId.OPEN_CAMERA: this.openCamera_(e); break; - case 'openVideo': + case OptionId.OPEN_VIDEO: this.openVideo_(e); break; - case 'openFolder': + case OptionId.OPEN_FOLDER: this.onSelectImageFromDisk_(e); break; - case 'profileImage': + case OptionId.PROFILE_IMAGE: this.onSelectProfileImage_(e); break; - case 'lastExternalImage': + case OptionId.LAST_EXTERNAL_IMAGE: this.onSelectLastExternalUserImage_(e); break; default: @@ -282,16 +324,17 @@ return 'false'; } switch (option.id) { - case 'openCamera': - case 'openVideo': - case 'openFolder': + case OptionId.OPEN_CAMERA: + case OptionId.OPEN_VIDEO: + case OptionId.OPEN_FOLDER: return 'false'; - case 'profileImage': + case OptionId.PROFILE_IMAGE: return (!!image && !!image.profileImage).toString(); - case 'lastExternalImage': + case OptionId.LAST_EXTERNAL_IMAGE: return (!!image && !!image.externalImage).toString(); default: // Handle default user image. + assert(isDefaultOption(option)); return (!!image && !!image.defaultImage && image.defaultImage.index === option.defaultImageIndex) .toString();
diff --git a/ash/webui/personalization_app/resources/trusted/user/user_preview_element.ts b/ash/webui/personalization_app/resources/trusted/user/user_preview_element.ts index a0b7aa87..a23399b 100644 --- a/ash/webui/personalization_app/resources/trusted/user/user_preview_element.ts +++ b/ash/webui/personalization_app/resources/trusted/user/user_preview_element.ts
@@ -100,11 +100,6 @@ if (value && old) { this.dispatchEvent(new AvatarChangedEvent()); } - if (old && old.url.startsWith('blob:')) { - // Revoke old object urls to clear memory. This is safe to call multiple - // times. - URL.revokeObjectURL(old.url); - } } private shouldShowMainPageView_(path: string, isEnterpriseManaged: boolean):
diff --git a/ash/webui/shimless_rma/backend/shimless_rma_service_unittest.cc b/ash/webui/shimless_rma/backend/shimless_rma_service_unittest.cc index eed10153..4044193 100644 --- a/ash/webui/shimless_rma/backend/shimless_rma_service_unittest.cc +++ b/ash/webui/shimless_rma/backend/shimless_rma_service_unittest.cc
@@ -309,34 +309,34 @@ base::RunLoop run_loop; - // Initialize current state, can_cancel=true, can_go_back=false + // Initialize current state, can_exit=true, can_go_back=false shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); - EXPECT_EQ(can_cancel, true); + EXPECT_EQ(can_exit, true); EXPECT_EQ(can_go_back, false); })); run_loop.RunUntilIdle(); - // Next state, can_cancel=false, can_go_back=true + // Next state, can_exit=false, can_go_back=true shimless_rma_provider_->SetSameOwner(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kSelectComponents); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); - EXPECT_EQ(can_cancel, false); + EXPECT_EQ(can_exit, false); EXPECT_EQ(can_go_back, true); })); run_loop.RunUntilIdle(); - // Previous state, can_cancel=true, can_go_back=false + // Previous state, can_exit=true, can_go_back=false shimless_rma_provider_->TransitionPreviousState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); - EXPECT_EQ(can_cancel, true); + EXPECT_EQ(can_exit, true); EXPECT_EQ(can_go_back, false); run_loop.Quit(); })); @@ -356,7 +356,7 @@ // Initialize current state shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, 
mojom::State::kWelcomeScreen); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -365,7 +365,7 @@ // With a WiFi network it should redirect to kUpdateOs shimless_rma_provider_->BeginFinalization(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateOs); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -385,7 +385,7 @@ // Initialize current state shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kWelcomeScreen); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -397,7 +397,7 @@ // No network should prompt select network page shimless_rma_provider_->BeginFinalization(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kConfigureNetwork); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -417,7 +417,7 @@ // Initialize current state shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kWelcomeScreen); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -426,7 +426,7 @@ // No network should prompt select network page shimless_rma_provider_->BeginFinalization(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kConfigureNetwork); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ 
-436,7 +436,7 @@ // With a WiFi network it should redirect to kUpdateOs shimless_rma_provider_->NetworkSelectionComplete(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateOs); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -456,7 +456,7 @@ // Initialize current state shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kWelcomeScreen); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -465,7 +465,7 @@ // No network should prompt select network page shimless_rma_provider_->BeginFinalization(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kConfigureNetwork); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -474,7 +474,7 @@ // With no network it should redirect to next rmad state shimless_rma_provider_->NetworkSelectionComplete(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kSelectComponents); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -491,7 +491,7 @@ base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kWelcomeScreen); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -500,7 +500,7 @@ // Sets it to `kConfigureNetwork` state. 
shimless_rma_provider_->BeginFinalization(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kConfigureNetwork); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -563,7 +563,7 @@ base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kWelcomeScreen); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -572,7 +572,7 @@ // Sets it to `kConfigureNetwork` state. shimless_rma_provider_->BeginFinalization(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kConfigureNetwork); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -618,7 +618,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RMAD_ERROR_OK); @@ -632,7 +632,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUnknown); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_RMA_NOT_REQUIRED); @@ -650,21 +650,21 @@ 
fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); })); run_loop.RunUntilIdle(); shimless_rma_provider_->SetSameOwner(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kSelectComponents); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); })); run_loop.RunUntilIdle(); shimless_rma_provider_->TransitionPreviousState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -683,7 +683,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->TransitionPreviousState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_TRANSITION_FAILED); @@ -698,14 +698,14 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, 
rmad::RmadErrorCode::RMAD_ERROR_OK); })); run_loop.RunUntilIdle(); shimless_rma_provider_->TransitionPreviousState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_TRANSITION_FAILED); @@ -714,7 +714,7 @@ run_loop.Run(); } -TEST_F(ShimlessRmaServiceTest, CanCancelRma) { +TEST_F(ShimlessRmaServiceTest, CanExitRma) { const std::vector<rmad::GetStateReply> fake_states = {CreateStateReply( rmad::RmadState::kDeviceDestination, rmad::RMAD_ERROR_OK)}; fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); @@ -756,7 +756,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -764,7 +764,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->SetSameOwner(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRestock); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -779,7 +779,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRestock); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -787,7 +787,7 @@ run_loop.RunUntilIdle(); 
shimless_rma_provider_->SetSameOwner(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRestock); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_REQUEST_INVALID); @@ -810,7 +810,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -818,7 +818,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->SetDifferentOwner(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRestock); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -840,7 +840,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseWipeDevice); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -850,7 +850,7 @@ const bool expected_wipe_device = true; shimless_rma_provider_->SetWipeDevice( expected_wipe_device, - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -866,14 +866,14 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool 
can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRestock); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); })); run_loop.RunUntilIdle(); shimless_rma_provider_->SetDifferentOwner(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRestock); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_REQUEST_INVALID); @@ -897,7 +897,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseWriteProtectDisableMethod); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -905,7 +905,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ChooseManuallyDisableWriteProtect( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -922,7 +922,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -930,7 +930,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ChooseManuallyDisableWriteProtect( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, 
rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -954,7 +954,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseWriteProtectDisableMethod); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -962,7 +962,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ChooseRsuDisableWriteProtect( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -978,7 +978,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -986,7 +986,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ChooseRsuDisableWriteProtect( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -1007,7 +1007,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kEnterRSUWPDisableCode); EXPECT_EQ(error, 
rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1033,7 +1033,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kEnterRSUWPDisableCode); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1059,7 +1059,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kEnterRSUWPDisableCode); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1118,7 +1118,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kEnterRSUWPDisableCode); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1127,7 +1127,7 @@ shimless_rma_provider_->SetRsuDisableWriteProtectCode( "test RSU unlock code", - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -1144,7 +1144,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, 
mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1153,7 +1153,7 @@ shimless_rma_provider_->SetRsuDisableWriteProtectCode( "test RSU unlock code", - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -1172,7 +1172,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kWaitForManualWPDisable); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1180,7 +1180,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->WriteProtectManuallyDisabled( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -1197,7 +1197,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1205,7 +1205,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->WriteProtectManuallyDisabled( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -1224,7 +1224,7 @@ 
fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kWPDisableComplete); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1232,7 +1232,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ConfirmManualWpDisableComplete( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -1249,7 +1249,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1257,7 +1257,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ConfirmManualWpDisableComplete( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -1279,7 +1279,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kWPDisableComplete); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1322,7 +1322,7 @@ 
fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kSelectComponents); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1356,7 +1356,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1424,7 +1424,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kSelectComponents); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1445,7 +1445,7 @@ shimless_rma_provider_->SetComponentList( std::move(components), - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -1461,7 +1461,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1476,7 +1476,7 @@ 
shimless_rma_provider_->SetComponentList( std::move(components), - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -1518,7 +1518,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kSelectComponents); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1526,7 +1526,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ReworkMainboard(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1541,7 +1541,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1549,7 +1549,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ReworkMainboard(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_REQUEST_INVALID); @@ -1572,7 +1572,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State 
state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateRoFirmware); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1580,7 +1580,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->RoFirmwareUpdateComplete(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1595,7 +1595,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1603,7 +1603,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->RoFirmwareUpdateComplete(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_REQUEST_INVALID); @@ -1626,7 +1626,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRestock); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1634,7 +1634,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ShutdownForRestock(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + 
[&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1649,7 +1649,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1657,7 +1657,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ShutdownForRestock(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_REQUEST_INVALID); @@ -1680,7 +1680,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRestock); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1688,7 +1688,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ContinueFinalizationAfterRestock( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -1705,7 +1705,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, 
bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1713,7 +1713,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ContinueFinalizationAfterRestock( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -1739,7 +1739,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateDeviceInformation); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1760,7 +1760,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1791,7 +1791,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateDeviceInformation); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1814,7 +1814,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - 
[&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1848,7 +1848,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateDeviceInformation); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1872,7 +1872,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1903,7 +1903,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateDeviceInformation); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1926,7 +1926,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, 
rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1957,7 +1957,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateDeviceInformation); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -1978,7 +1978,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2009,7 +2009,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateDeviceInformation); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2030,7 +2030,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2061,7 +2061,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool 
can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateDeviceInformation); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2082,7 +2082,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2113,7 +2113,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateDeviceInformation); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2134,7 +2134,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2166,7 +2166,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kUpdateDeviceInformation); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2175,7 +2175,7 @@ shimless_rma_provider_->SetDeviceInformation( "serial 
number", 1, 2, 3, "123-456-789", - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -2191,7 +2191,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2200,7 +2200,7 @@ shimless_rma_provider_->SetDeviceInformation( "serial number", 1, 2, 3, "123-456-789", - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -2232,7 +2232,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kCheckCalibration); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2254,7 +2254,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2285,7 +2285,7 @@ 
fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kSetupCalibration); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2311,7 +2311,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2353,7 +2353,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kCheckCalibration); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2372,7 +2372,7 @@ shimless_rma_provider_->StartCalibration( std::move(components), - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -2390,7 +2390,7 @@ base::BindRepeating([](const rmad::RmadState& state) { NOTREACHED(); }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2407,7 
+2407,7 @@ shimless_rma_provider_->StartCalibration( std::move(components), - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -2425,7 +2425,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kSetupCalibration); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2433,7 +2433,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->RunCalibrationStep(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2448,7 +2448,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2456,7 +2456,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->RunCalibrationStep(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_REQUEST_INVALID); @@ -2472,7 +2472,7 @@ 
fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRunCalibration); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2480,7 +2480,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ContinueCalibration(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2495,7 +2495,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2503,7 +2503,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ContinueCalibration(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_REQUEST_INVALID); @@ -2520,7 +2520,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRunCalibration); 
EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2528,7 +2528,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->CalibrationComplete(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2543,7 +2543,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2551,7 +2551,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->CalibrationComplete(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_REQUEST_INVALID); @@ -2574,7 +2574,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kProvisionDevice); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2582,7 +2582,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ProvisioningComplete(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ 
-2597,7 +2597,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2605,7 +2605,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->ProvisioningComplete(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_REQUEST_INVALID); @@ -2622,7 +2622,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kProvisionDevice); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2630,7 +2630,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->RetryProvisioning(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2647,7 +2647,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kFinalize); 
EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2655,7 +2655,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->FinalizationComplete(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2678,7 +2678,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kFinalize); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2686,7 +2686,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->RetryFinalization(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2701,7 +2701,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2709,7 +2709,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->FinalizationComplete(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_REQUEST_INVALID); @@ -2726,7 
+2726,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kWaitForManualWPEnable); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2734,7 +2734,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->WriteProtectManuallyEnabled( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -2750,7 +2750,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2758,7 +2758,7 @@ run_loop.RunUntilIdle(); shimless_rma_provider_->WriteProtectManuallyEnabled( - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -2776,7 +2776,7 @@ fake_rmad_client_()->SetGetLogReply(expected_log, rmad::RMAD_ERROR_OK); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRepairComplete); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2797,7 +2797,7 @@ 
fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2824,7 +2824,7 @@ rmad::RMAD_ERROR_OK); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRepairComplete); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2845,7 +2845,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2870,7 +2870,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRepairComplete); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2899,7 +2899,7 @@ }); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRepairComplete); 
EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2908,7 +2908,7 @@ shimless_rma_provider_->EndRma( rmad::RepairCompleteState::RMAD_REPAIR_COMPLETE_SHUTDOWN, - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -2924,7 +2924,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -2933,7 +2933,7 @@ shimless_rma_provider_->EndRma( rmad::RepairCompleteState::RMAD_REPAIR_COMPLETE_SHUTDOWN, - base::BindLambdaForTesting([&](mojom::State state, bool can_cancel, + base::BindLambdaForTesting([&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kChooseDestination); @@ -2949,7 +2949,7 @@ fake_rmad_client_()->SetFakeStateReplies(std::move(fake_states)); base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRepairComplete); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -3067,7 +3067,7 @@ base::RunLoop run_loop; shimless_rma_provider_->GetCurrentState(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kSetupCalibration); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK); @@ -3089,7 
+3089,7 @@ EXPECT_EQ(fake_observer.component_observations[0].progress(), 0.25); shimless_rma_provider_->RunCalibrationStep(base::BindLambdaForTesting( - [&](mojom::State state, bool can_cancel, bool can_go_back, + [&](mojom::State state, bool can_exit, bool can_go_back, rmad::RmadErrorCode error) { EXPECT_EQ(state, mojom::State::kRunCalibration); EXPECT_EQ(error, rmad::RmadErrorCode::RMAD_ERROR_OK);
diff --git a/ash/webui/shimless_rma/mojom/shimless_rma.mojom b/ash/webui/shimless_rma/mojom/shimless_rma.mojom index 1aa625d6..0de30b8 100644 --- a/ash/webui/shimless_rma/mojom/shimless_rma.mojom +++ b/ash/webui/shimless_rma/mojom/shimless_rma.mojom
@@ -479,12 +479,12 @@ // Used on application start to determine the location in the RMA flow. // Due to reboots it may not always be the welcome screen. GetCurrentState() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); // Attempt to roll back to the previous RMA state. // Returns the updated state to display and an error code. TransitionPreviousState() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); // Attempts to abort the RMA. @@ -496,7 +496,7 @@ // // User has confirmed they wish to finalize RMA. BeginFinalization() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -508,7 +508,7 @@ // Called when next is clicked after a network is successfully connected or // the user skips connecting to a network. NetworkSelectionComplete() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -522,7 +522,7 @@ UpdateOs() => (bool update_started); // Skips OS update. UpdateOsSkipped() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -531,12 +531,12 @@ // Set the RMA state for the device to be kept by the current owner. // Returns the next state to display and an error code. SetSameOwner() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); // Set the RMA state for the device to be given to a different owner. // Returns the next state to display and an error code. 
SetDifferentOwner() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -544,7 +544,7 @@ // // Set the RMA state to wipe or preserve the device data on RMA completion. SetWipeDevice(bool should_wipe_device) - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -556,14 +556,14 @@ // TODO(crbug.com/1218175): Rename SetManuallyDisableWriteProtect for // consistency with other methods. ChooseManuallyDisableWriteProtect() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); // Choose to disable HWWP using the RSU code method. // Returns the next state to display and an error code. // TODO(crbug.com/1218175): Rename SetRsuDisableWriteProtect for // consistency with other methods. ChooseRsuDisableWriteProtect() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -581,7 +581,7 @@ // Attempt to disable HWWP using a RSU code. // Returns the next state to display and an error code. SetRsuDisableWriteProtectCode(string code) - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -590,7 +590,7 @@ // Transition to next state after manual write protect disabled signal has // been received. 
WriteProtectManuallyDisabled() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); // Returns a display string and QR Code image representing the URL that takes // users to the manufacturer specific instructions page for manually disabling @@ -610,7 +610,7 @@ // User acknowledges manual HWWP disable is complete and transitions to next // state. ConfirmManualWpDisableComplete() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -624,11 +624,11 @@ // This list only needs to contain the components set as repaired (any others // included will be ignored by rmad service). SetComponentList(array<Component> components) - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); // Go to rework flow. ReworkMainboard() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -636,7 +636,7 @@ // // Continue after firmware reimaging completes. RoFirmwareUpdateComplete() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); @@ -646,11 +646,11 @@ // Shutdown the device so mainboard can be restocked. // Note: This will only return a result if there is an error. ShutdownForRestock() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); // Continue RMA finalization after mainboard is used in another device. 
ContinueFinalizationAfterRestock() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -677,7 +677,7 @@ SetDeviceInformation( string serial_number, int32 region_index, int32 sku_index, int32 white_label_index, string dram_part_number) - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -703,31 +703,31 @@ // Next state will be kSetupCalibration if setup is required, or // kRunCalibration if not. StartCalibration(array<CalibrationComponentStatus> components) - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); // Request transition from kSetupCalibration to run this calibration step. RunCalibrationStep() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); // Request transition from kRunCalibration to the next setup state. ContinueCalibration() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); // Request transition from kRunCalibratoin to the next RMA state. // This can only be called after kCalibrationOverallComplete has been // observed. CalibrationComplete() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// // Methods for kProvisionDevice state // // Retries provisioning after a blocking failure. 
- RetryProvisioning() => (State state, bool can_cancel, bool can_go_back, + RetryProvisioning() => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); ProvisioningComplete() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -735,10 +735,10 @@ // // Retries provisioning after a failure. RetryFinalization() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); FinalizationComplete() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -747,7 +747,7 @@ // Transition to next state after manual write protect enabled signal has been // received. WriteProtectManuallyEnabled() - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -766,7 +766,7 @@ // Complete RMA using the `shutdown_method`. // Returns an error indicating success or a failure. EndRma(ShutdownMethod shutdown_method) - => (State state, bool can_cancel, bool can_go_back, + => (State state, bool can_exit, bool can_go_back, RmadErrorCode error); /////////////////////////////////////// @@ -775,9 +775,9 @@ // Currently the only critical error is when error is kRmaNotRequired // (state is kUnknown). // - // Attempt to cancel RMA and restart Chrome without checking RMA. + // Attempt to exit RMA and restart Chrome without checking RMA. CriticalErrorExitToLogin() => (RmadErrorCode error); - // Attempt to cancel RMA and reboot the device. + // Attempt to exit RMA and reboot the device. CriticalErrorReboot() => (RmadErrorCode error); ///////////////////////////////////////
diff --git a/ash/webui/shimless_rma/resources/fake_data.js b/ash/webui/shimless_rma/resources/fake_data.js index bc5fef5..1b78205 100644 --- a/ash/webui/shimless_rma/resources/fake_data.js +++ b/ash/webui/shimless_rma/resources/fake_data.js
@@ -10,127 +10,127 @@ export const fakeStates = [ { state: State.kWelcomeScreen, - canCancel: true, + canExit: true, canGoBack: false, error: RmadErrorCode.kOk }, { state: State.kConfigureNetwork, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kUpdateOs, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kSelectComponents, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kChooseDestination, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kChooseWipeDevice, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kChooseWriteProtectDisableMethod, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kEnterRSUWPDisableCode, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kWaitForManualWPDisable, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kWPDisableComplete, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kUpdateRoFirmware, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kUpdateDeviceInformation, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kRestock, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kCheckCalibration, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kSetupCalibration, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kRunCalibration, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kProvisionDevice, - canCancel: true, + canExit: true, canGoBack: true, error: 
RmadErrorCode.kOk }, { state: State.kWaitForManualWPEnable, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kFinalize, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kRepairComplete, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }, { state: State.kUnknown, - canCancel: false, + canExit: false, canGoBack: false, error: RmadErrorCode.kOk },
diff --git a/ash/webui/shimless_rma/resources/fake_shimless_rma_service.js b/ash/webui/shimless_rma/resources/fake_shimless_rma_service.js index 32b9d96..5e32cac 100644 --- a/ash/webui/shimless_rma/resources/fake_shimless_rma_service.js +++ b/ash/webui/shimless_rma/resources/fake_shimless_rma_service.js
@@ -134,7 +134,7 @@ assert(this.stateIndex_ < this.states_.length); const state = this.states_[this.stateIndex_]; this.setFakeCurrentState_( - state.state, state.canCancel, state.canGoBack, state.error); + state.state, state.canExit, state.canGoBack, state.error); } return this.methods_.resolveMethodWithDelay( 'getCurrentState', this.resolveMethodDelayMs_); @@ -155,13 +155,13 @@ assert(this.stateIndex_ < this.states_.length); const state = this.states_[this.stateIndex_]; this.setFakePrevState_( - state.state, state.canCancel, state.canGoBack, + state.state, state.canExit, state.canGoBack, RmadErrorCode.kTransitionFailed); } else { this.stateIndex_--; const state = this.states_[this.stateIndex_]; this.setFakePrevState_( - state.state, state.canCancel, state.canGoBack, state.error); + state.state, state.canExit, state.canGoBack, state.error); } return this.methods_.resolveMethodWithDelay( 'transitionPreviousState', this.resolveMethodDelayMs_); @@ -1242,7 +1242,7 @@ this.methods_.register('abortRma'); - this.methods_.register('canCancel'); + this.methods_.register('canExit'); this.methods_.register('canGoBack'); this.methods_.register('beginFinalization'); @@ -1359,13 +1359,13 @@ assert(this.stateIndex_ < this.states_.length); const state = this.states_[this.stateIndex_]; this.setFakeStateForMethod_( - method, state.state, state.canCancel, state.canGoBack, + method, state.state, state.canExit, state.canGoBack, RmadErrorCode.kTransitionFailed); } else if (this.states_[this.stateIndex_].state !== expectedState) { // Error: Called in wrong state. const state = this.states_[this.stateIndex_]; this.setFakeStateForMethod_( - method, state.state, state.canCancel, state.canGoBack, + method, state.state, state.canExit, state.canGoBack, RmadErrorCode.kRequestInvalid); } else { // Success. 
@@ -1377,7 +1377,7 @@ } const state = this.states_[this.stateIndex_]; this.setFakeStateForMethod_( - method, state.state, state.canCancel, state.canGoBack, state.error); + method, state.state, state.canExit, state.canGoBack, state.error); } return this.methods_.resolveMethodWithDelay( method, this.resolveMethodDelayMs_); @@ -1386,28 +1386,28 @@ /** * Sets the value that will be returned when calling getCurrent(). * @param {!State} state - * @param {boolean} canCancel, + * @param {boolean} canExit, * @param {boolean} canGoBack, * @param {!RmadErrorCode} error * @private */ - setFakeCurrentState_(state, canCancel, canGoBack, error) { + setFakeCurrentState_(state, canExit, canGoBack, error) { this.setFakeStateForMethod_( - 'getCurrentState', state, canCancel, canGoBack, error); + 'getCurrentState', state, canExit, canGoBack, error); } /** * Sets the value that will be returned when calling * transitionPreviousState(). * @param {!State} state - * @param {boolean} canCancel, + * @param {boolean} canExit, * @param {boolean} canGoBack, * @param {!RmadErrorCode} error * @private */ - setFakePrevState_(state, canCancel, canGoBack, error) { + setFakePrevState_(state, canExit, canGoBack, error) { this.setFakeStateForMethod_( - 'transitionPreviousState', state, canCancel, canGoBack, error); + 'transitionPreviousState', state, canExit, canGoBack, error); } /** @@ -1415,15 +1415,15 @@ * that update state. e.g. setSameOwner() * @param {string} method * @param {!State} state - * @param {boolean} canCancel, + * @param {boolean} canExit, * @param {boolean} canGoBack, * @param {!RmadErrorCode} error * @private */ - setFakeStateForMethod_(method, state, canCancel, canGoBack, error) { + setFakeStateForMethod_(method, state, canExit, canGoBack, error) { this.methods_.setResult(method, /** @type {!StateResult} */ ({ state: state, - canCancel: canCancel, + canExit: canExit, canGoBack: canGoBack, error: error }));
diff --git a/ash/webui/shimless_rma/resources/shimless_rma.js b/ash/webui/shimless_rma/resources/shimless_rma.js index 2ca89ed2..cc60f8e9 100644 --- a/ash/webui/shimless_rma/resources/shimless_rma.js +++ b/ash/webui/shimless_rma/resources/shimless_rma.js
@@ -512,7 +512,7 @@ if (error === RmadErrorCode.kRmaNotRequired) { const errorState = { state: State.kUnknown, - canCancel: false, + canExit: false, canGoBack: false, error: RmadErrorCode.kRmaNotRequired }; @@ -548,7 +548,7 @@ // Set the next page as the current page. this.currentPage_ = nextStatePageInfo; - if (!stateResult.canCancel) { + if (!stateResult.canExit) { this.currentPage_.buttonExit = ButtonState.HIDDEN; } if (!stateResult.canGoBack) {
diff --git a/ash/webui/shimless_rma/resources/shimless_rma_types.js b/ash/webui/shimless_rma/resources/shimless_rma_types.js index e4b0c142..82fe973 100644 --- a/ash/webui/shimless_rma/resources/shimless_rma_types.js +++ b/ash/webui/shimless_rma/resources/shimless_rma_types.js
@@ -21,7 +21,7 @@ * this is used frequently. * @typedef {{ * state: !State, - * canCancel: boolean, + * canExit: boolean, * canGoBack: boolean, * error: !RmadErrorCode * }}
diff --git a/ash/wm/desks/templates/restore_data_collector.cc b/ash/wm/desks/templates/restore_data_collector.cc index a115aa1..e262f6b 100644 --- a/ash/wm/desks/templates/restore_data_collector.cc +++ b/ash/wm/desks/templates/restore_data_collector.cc
@@ -54,8 +54,11 @@ shell->mru_window_tracker()->BuildMruWindowList(kActiveDesk); auto* delegate = shell->desks_templates_delegate(); for (auto* window : mru_windows) { - if (!delegate->IsWindowSupportedForDeskTemplate(window) && - !wm::GetTransientParent(window)) { + // Skip transient windows without reporting. + if (wm::GetTransientParent(window)) + continue; + + if (!delegate->IsWindowSupportedForDeskTemplate(window)) { call.unsupported_apps.push_back(window); continue; }
diff --git a/ash/wm/desks/templates/saved_desk_item_view.cc b/ash/wm/desks/templates/saved_desk_item_view.cc index 364e3a3..d3c9c1c 100644 --- a/ash/wm/desks/templates/saved_desk_item_view.cc +++ b/ash/wm/desks/templates/saved_desk_item_view.cc
@@ -36,6 +36,7 @@ #include "base/strings/utf_string_conversions.h" #include "base/time/time.h" #include "chromeos/ui/vector_icons/vector_icons.h" +#include "ui/accessibility/ax_enums.mojom.h" #include "ui/base/l10n/l10n_util.h" #include "ui/base/l10n/time_format.h" #include "ui/base/metadata/metadata_impl_macros.h" @@ -339,6 +340,25 @@ name_view_->OnContentsChanged(); } +void SavedDeskItemView::GetAccessibleNodeData(ui::AXNodeData* node_data) { + int accessible_text_id = + desk_template_->type() == DeskTemplateType::kTemplate + ? IDS_ASH_DESKS_TEMPLATES_LIBRARY_TEMPLATES_GRID_ITEM_ACCESSIBLE_NAME + : IDS_ASH_DESKS_TEMPLATES_LIBRARY_SAVE_AND_RECALL_GRID_ITEM_ACCESSIBLE_NAME; + + node_data->role = ax::mojom::Role::kButton; + + node_data->AddStringAttribute( + ax::mojom::StringAttribute::kName, + l10n_util::GetStringFUTF8(accessible_text_id, + desk_template_->template_name())); + + node_data->AddStringAttribute( + ax::mojom::StringAttribute::kDescription, + l10n_util::GetStringUTF8( + IDS_ASH_DESKS_TEMPLATES_LIBRARY_SAVED_DESK_GRID_ITEM_EXTRA_ACCESSIBLE_DESCRIPTION)); +} + void SavedDeskItemView::Layout() { views::View::Layout();
diff --git a/ash/wm/desks/templates/saved_desk_item_view.h b/ash/wm/desks/templates/saved_desk_item_view.h index 7526dbe..60f59884 100644 --- a/ash/wm/desks/templates/saved_desk_item_view.h +++ b/ash/wm/desks/templates/saved_desk_item_view.h
@@ -10,6 +10,7 @@ #include "ash/wm/overview/overview_highlightable_view.h" #include "base/memory/weak_ptr.h" #include "base/scoped_observation.h" +#include "ui/accessibility/ax_node_data.h" #include "ui/base/metadata/metadata_header_macros.h" #include "ui/views/controls/button/button.h" #include "ui/views/controls/textfield/textfield_controller.h" @@ -108,6 +109,7 @@ void UpdateTemplate(const DeskTemplate& updated_template); // views::Button: + void GetAccessibleNodeData(ui::AXNodeData* node_data) override; void Layout() override; void OnThemeChanged() override; void OnViewFocused(views::View* observed_view) override;
diff --git a/base/allocator/partition_allocator/BUILD.gn b/base/allocator/partition_allocator/BUILD.gn index 1e14a2c..4e817e4 100644 --- a/base/allocator/partition_allocator/BUILD.gn +++ b/base/allocator/partition_allocator/BUILD.gn
@@ -121,6 +121,7 @@ "partition_alloc_base/time/time.h", "partition_alloc_base/time/time_override.cc", "partition_alloc_base/time/time_override.h", + "partition_alloc_base/win/windows_types.h", "partition_alloc_check.h", "partition_alloc_config.h", "partition_alloc_constants.h",
diff --git a/base/allocator/partition_allocator/DEPS b/base/allocator/partition_allocator/DEPS index 17698142..56a1d09 100644 --- a/base/allocator/partition_allocator/DEPS +++ b/base/allocator/partition_allocator/DEPS
@@ -12,7 +12,6 @@ "+base/mac/mac_util.h", "+base/mac/scoped_cftyperef.h", "+base/debug/debugging_buildflags.h", - "+base/win/windows_types.h", "+build/build_config.h", "+build/buildflag.h", "+build/chromeos_buildflags.h",
diff --git a/base/allocator/partition_allocator/oom.h b/base/allocator/partition_allocator/oom.h index b16d777e..a0d564f 100644 --- a/base/allocator/partition_allocator/oom.h +++ b/base/allocator/partition_allocator/oom.h
@@ -13,7 +13,7 @@ #include "build/build_config.h" #if BUILDFLAG(IS_WIN) -#include "base/win/windows_types.h" +#include "base/allocator/partition_allocator/partition_alloc_base/win/windows_types.h" #endif namespace partition_alloc {
diff --git a/base/allocator/partition_allocator/partition_alloc_base/threading/platform_thread.h b/base/allocator/partition_allocator/partition_alloc_base/threading/platform_thread.h index 30f4aab..62a4dec 100644 --- a/base/allocator/partition_allocator/partition_alloc_base/threading/platform_thread.h +++ b/base/allocator/partition_allocator/partition_alloc_base/threading/platform_thread.h
@@ -20,7 +20,7 @@ #include "build/build_config.h" #if BUILDFLAG(IS_WIN) -#include "base/win/windows_types.h" +#include "base/allocator/partition_allocator/partition_alloc_base/win/windows_types.h" #elif BUILDFLAG(IS_FUCHSIA) #include <zircon/types.h> #elif BUILDFLAG(IS_APPLE)
diff --git a/base/allocator/partition_allocator/partition_alloc_base/threading/platform_thread_ref.h b/base/allocator/partition_allocator/partition_alloc_base/threading/platform_thread_ref.h index 0ba3a41..c775edf 100644 --- a/base/allocator/partition_allocator/partition_alloc_base/threading/platform_thread_ref.h +++ b/base/allocator/partition_allocator/partition_alloc_base/threading/platform_thread_ref.h
@@ -17,7 +17,7 @@ #include "build/build_config.h" #if BUILDFLAG(IS_WIN) -#include "base/win/windows_types.h" +#include "base/allocator/partition_allocator/partition_alloc_base/win/windows_types.h" #elif BUILDFLAG(IS_POSIX) || BUILDFLAG(IS_FUCHSIA) #include <pthread.h> #endif
diff --git a/base/allocator/partition_allocator/partition_alloc_base/time/time.h b/base/allocator/partition_allocator/partition_alloc_base/time/time.h index f7888dd..e596d99 100644 --- a/base/allocator/partition_allocator/partition_alloc_base/time/time.h +++ b/base/allocator/partition_allocator/partition_alloc_base/time/time.h
@@ -96,7 +96,7 @@ #endif #if BUILDFLAG(IS_WIN) -#include "base/win/windows_types.h" +#include "base/allocator/partition_allocator/partition_alloc_base/win/windows_types.h" namespace ABI { namespace Windows {
diff --git a/base/allocator/partition_allocator/partition_alloc_base/win/windows_types.h b/base/allocator/partition_allocator/partition_alloc_base/win/windows_types.h new file mode 100644 index 0000000..036ffd7 --- /dev/null +++ b/base/allocator/partition_allocator/partition_alloc_base/win/windows_types.h
@@ -0,0 +1,88 @@ +// Copyright (c) 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file contains defines and typedefs that allow popular Windows types to +// be used without the overhead of including windows.h. + +#ifndef BASE_ALLOCATOR_PARTITION_ALLOCATOR_PARTITION_ALLOC_BASE_WIN_WINDOWS_TYPES_H_ +#define BASE_ALLOCATOR_PARTITION_ALLOCATOR_PARTITION_ALLOC_BASE_WIN_WINDOWS_TYPES_H_ + +// Needed for function prototypes. +#include <specstrings.h> + +#ifdef __cplusplus +extern "C" { +#endif + +// typedef and define the most commonly used Windows integer types. + +typedef unsigned long DWORD; +typedef long LONG; +typedef __int64 LONGLONG; +typedef unsigned __int64 ULONGLONG; + +#define VOID void +typedef char CHAR; +typedef short SHORT; +typedef long LONG; +typedef int INT; +typedef unsigned int UINT; +typedef unsigned int* PUINT; +typedef unsigned __int64 UINT64; +typedef void* LPVOID; +typedef void* PVOID; +typedef void* HANDLE; +typedef int BOOL; +typedef unsigned char BYTE; +typedef BYTE BOOLEAN; +typedef DWORD ULONG; +typedef unsigned short WORD; +typedef WORD UWORD; +typedef WORD ATOM; + +// Forward declare some Windows struct/typedef sets. + +typedef struct _RTL_SRWLOCK RTL_SRWLOCK; +typedef RTL_SRWLOCK SRWLOCK, *PSRWLOCK; + +typedef struct _FILETIME FILETIME; + +struct PA_CHROME_SRWLOCK { + PVOID Ptr; +}; + +// The trailing white-spaces after this macro are required, for compatibility +// with the definition in winnt.h. +#define RTL_SRWLOCK_INIT {0} // NOLINT +#define SRWLOCK_INIT RTL_SRWLOCK_INIT + +// clang-format on + +// Define some macros needed when prototyping Windows functions. + +#define DECLSPEC_IMPORT __declspec(dllimport) +#define WINBASEAPI DECLSPEC_IMPORT +#define WINAPI __stdcall + +// Needed for LockImpl. 
+WINBASEAPI _Releases_exclusive_lock_(*SRWLock) VOID WINAPI + ReleaseSRWLockExclusive(_Inout_ PSRWLOCK SRWLock); +WINBASEAPI BOOLEAN WINAPI TryAcquireSRWLockExclusive(_Inout_ PSRWLOCK SRWLock); + +// Needed for thread_local_storage.h +WINBASEAPI LPVOID WINAPI TlsGetValue(_In_ DWORD dwTlsIndex); + +WINBASEAPI BOOL WINAPI TlsSetValue(_In_ DWORD dwTlsIndex, + _In_opt_ LPVOID lpTlsValue); + +WINBASEAPI _Check_return_ _Post_equals_last_error_ DWORD WINAPI + GetLastError(VOID); + +WINBASEAPI VOID WINAPI SetLastError(_In_ DWORD dwErrCode); + +#ifdef __cplusplus +} +#endif + +#endif // BASE_ALLOCATOR_PARTITION_ALLOCATOR_PARTITION_ALLOC_BASE_WIN_WINDOWS_TYPES_H_
diff --git a/base/allocator/partition_allocator/partition_tls.h b/base/allocator/partition_allocator/partition_tls.h index ec90b7b7..4f8c699 100644 --- a/base/allocator/partition_allocator/partition_tls.h +++ b/base/allocator/partition_allocator/partition_tls.h
@@ -16,7 +16,7 @@ #endif #if BUILDFLAG(IS_WIN) -#include "base/win/windows_types.h" +#include "base/allocator/partition_allocator/partition_alloc_base/win/windows_types.h" #endif // Barebones TLS implementation for use in PartitionAlloc. This doesn't use the
diff --git a/base/allocator/partition_allocator/spinning_mutex.h b/base/allocator/partition_allocator/spinning_mutex.h index a6f92c9..480ec02b 100644 --- a/base/allocator/partition_allocator/spinning_mutex.h +++ b/base/allocator/partition_allocator/spinning_mutex.h
@@ -18,7 +18,7 @@ #include "build/build_config.h" #if BUILDFLAG(IS_WIN) -#include "base/win/windows_types.h" +#include "base/allocator/partition_allocator/partition_alloc_base/win/windows_types.h" #endif #if BUILDFLAG(IS_POSIX) @@ -98,7 +98,7 @@ std::atomic<int32_t> state_{kUnlocked}; #elif BUILDFLAG(IS_WIN) - CHROME_SRWLOCK lock_ = SRWLOCK_INIT; + PA_CHROME_SRWLOCK lock_ = SRWLOCK_INIT; #elif BUILDFLAG(IS_POSIX) pthread_mutex_t lock_ = PTHREAD_MUTEX_INITIALIZER; #elif BUILDFLAG(IS_FUCHSIA)
diff --git a/chrome/android/features/start_surface/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceFinaleTest.java b/chrome/android/features/start_surface/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceFinaleTest.java index c25ed11..e1c990c 100644 --- a/chrome/android/features/start_surface/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceFinaleTest.java +++ b/chrome/android/features/start_surface/javatests/src/org/chromium/chrome/features/start_surface/StartSurfaceFinaleTest.java
@@ -46,6 +46,7 @@ import org.chromium.base.test.util.CallbackHelper; import org.chromium.base.test.util.CommandLineFlags; import org.chromium.base.test.util.CriteriaHelper; +import org.chromium.base.test.util.DisabledTest; import org.chromium.base.test.util.Feature; import org.chromium.base.test.util.Restriction; import org.chromium.chrome.R; @@ -146,6 +147,7 @@ @LargeTest @Feature({"StartSurface"}) @CommandLineFlags.Add({START_SURFACE_TEST_BASE_PARAMS + "/omnibox_focused_on_new_tab/true"}) + @DisabledTest(message = "This test blocks http://crrev/c/3665321") public void testOmnibox_FocusedOnNewTabInSingleSurface() { if (!mImmediateReturn) { StartSurfaceTestUtils.pressHomePageButton(mActivityTestRule.getActivity()); @@ -201,6 +203,7 @@ // clang-format off @CommandLineFlags.Add({START_SURFACE_TEST_BASE_PARAMS + "/show_last_active_tab_only/true" + "/exclude_mv_tiles/true/omnibox_focused_on_new_tab/true"}) + @DisabledTest(message = "This test blocks http://crrev/c/3665321") public void testOmnibox_FocusedOnNewTabInSingleSurfaceV2() { // clang-format on if (!mImmediateReturn) {
diff --git a/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/tab_management/TabListRecyclerView.java b/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/tab_management/TabListRecyclerView.java index 630576c..5c89dd3 100644 --- a/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/tab_management/TabListRecyclerView.java +++ b/chrome/android/features/tab_ui/java/src/org/chromium/chrome/browser/tasks/tab_management/TabListRecyclerView.java
@@ -149,6 +149,8 @@ private ImageView mShadowImageView; private int mShadowTopOffset; private TabListOnScrollListener mScrollListener; + // It is null when gts-tab animation is disabled or switching from Start surface to GTS. + @Nullable private RecyclerView.ItemAnimator mOriginalAnimator; /** @@ -203,7 +205,12 @@ mFadeInAnimator = null; mListener.finishedShowing(); // Restore the original value. - setItemAnimator(mOriginalAnimator); + // TODO(crbug.com/1315676): Remove the null check after decoupling Start surface + // layout and grid tab switcher layout. + if (mOriginalAnimator != null) { + setItemAnimator(mOriginalAnimator); + mOriginalAnimator = null; + } setShadowVisibility(computeVerticalScrollOffset() > 0); if (mDynamicView != null) { mDynamicView.dropCachedBitmap();
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/query_tiles/QueryTileUtils.java b/chrome/android/java/src/org/chromium/chrome/browser/query_tiles/QueryTileUtils.java index c521146..d60da4e 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/query_tiles/QueryTileUtils.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/query_tiles/QueryTileUtils.java
@@ -18,9 +18,9 @@ import org.chromium.chrome.browser.preferences.SharedPreferencesManager; import org.chromium.chrome.browser.profiles.Profile; import org.chromium.chrome.browser.segmentation_platform.SegmentationPlatformServiceFactory; -import org.chromium.components.optimization_guide.proto.ModelsProto.OptimizationTarget; import org.chromium.components.segmentation_platform.SegmentSelectionResult; import org.chromium.components.segmentation_platform.SegmentationPlatformService; +import org.chromium.components.segmentation_platform.proto.SegmentationProto.SegmentId; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; @@ -243,7 +243,7 @@ if (!result.isReady) { segmentationResult = ShowQueryTilesSegmentationResult.UNINITIALIZED; } else if (result.selectedSegment - == OptimizationTarget.OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES) { + == SegmentId.OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES) { segmentationResult = ShowQueryTilesSegmentationResult.SHOW; } else { segmentationResult = ShowQueryTilesSegmentationResult.DONT_SHOW;
diff --git a/chrome/android/java/src/org/chromium/chrome/browser/tasks/ReturnToChromeUtil.java b/chrome/android/java/src/org/chromium/chrome/browser/tasks/ReturnToChromeUtil.java index 0005dda..577fec7 100644 --- a/chrome/android/java/src/org/chromium/chrome/browser/tasks/ReturnToChromeUtil.java +++ b/chrome/android/java/src/org/chromium/chrome/browser/tasks/ReturnToChromeUtil.java
@@ -59,8 +59,8 @@ import org.chromium.components.browser_ui.widget.gesture.BackPressHandler; import org.chromium.components.embedder_support.util.UrlConstants; import org.chromium.components.embedder_support.util.UrlUtilities; -import org.chromium.components.optimization_guide.proto.ModelsProto.OptimizationTarget; import org.chromium.components.segmentation_platform.SegmentationPlatformService; +import org.chromium.components.segmentation_platform.proto.SegmentationProto.SegmentId; import org.chromium.components.signin.identitymanager.ConsentLevel; import org.chromium.components.user_prefs.UserPrefs; import org.chromium.content_public.browser.LoadUrlParams; @@ -842,7 +842,7 @@ if (!result.isReady) { resultEnum = ShowChromeStartSegmentationResult.UNINITIALIZED; } else if (result.selectedSegment - == OptimizationTarget.OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID) { + == SegmentId.OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID) { resultEnum = ShowChromeStartSegmentationResult.SHOW; } else { resultEnum = ShowChromeStartSegmentationResult.DONT_SHOW;
diff --git a/chrome/app/chromeos_shared_strings.grdp b/chrome/app/chromeos_shared_strings.grdp index c3bcbe7..ca1c739 100644 --- a/chrome/app/chromeos_shared_strings.grdp +++ b/chrome/app/chromeos_shared_strings.grdp
@@ -9,4 +9,18 @@ Accounts </message> + <!-- RequestPin dialog messages --> + <message name="IDS_REQUEST_PIN_DIALOG_HEADER" desc="The text displayed in the certificate provider PIN request dialog."> + "<ph name="EXTENSION_NAME">$1<ex>My Extension</ex></ph>" is requesting your <ph name="CODE_TYPE">$2<ex>PIN</ex></ph> + </message> + <message name="IDS_REQUEST_PIN_DIALOG_PROCESSING" desc="The text displayed while the certificate provider API is waiting for response from extension."> + Processing request... + </message> + <message name="IDS_REQUEST_PIN_DIALOG_PIN" desc="The Provider Identification Number abbreviation"> + PIN + </message> + <message name="IDS_REQUEST_PIN_DIALOG_PUK" desc="The Personal Unlocking Key (as used in mobile phones) abbreviation"> + PUK + </message> + </grit-part>
diff --git a/chrome/app/chromeos_strings.grdp b/chrome/app/chromeos_strings.grdp index bdf9a94..a85481e 100644 --- a/chrome/app/chromeos_strings.grdp +++ b/chrome/app/chromeos_strings.grdp
@@ -4367,20 +4367,6 @@ <ph name="PRINTER_NAME">$1<ex> Google InkJet 1234</ex></ph> is connected but needs configuration </message> - <!-- RequestPin dialog messages --> - <message name="IDS_REQUEST_PIN_DIALOG_HEADER" desc="The text displayed in the certificate provider PIN request dialog."> - "<ph name="EXTENSION_NAME">$1<ex>My Extension</ex></ph>" is requesting your <ph name="CODE_TYPE">$2<ex>PIN</ex></ph> - </message> - <message name="IDS_REQUEST_PIN_DIALOG_PROCESSING" desc="The text displayed while the certificate provider API is waiting for response from extension."> - Processing request... - </message> - <message name="IDS_REQUEST_PIN_DIALOG_PIN" desc="The Provider Identification Number abbreviation"> - PIN - </message> - <message name="IDS_REQUEST_PIN_DIALOG_PUK" desc="The Personal Unlocking Key (as used in mobile phones) abbreviation"> - PUK - </message> - <message name="IDS_USB_PRINTER_UNKNOWN_DISPLAY_NAME" desc="Display string used for USB printers where we have identified neither the correct manufactuer nor model strings. USB stands for Universal Serial Bus, it indicates how the printer is attached to the computer."> Unknown Printer (USB) </message>
diff --git a/chrome/browser/BUILD.gn b/chrome/browser/BUILD.gn index 0e09c0b7..59d7970a 100644 --- a/chrome/browser/BUILD.gn +++ b/chrome/browser/BUILD.gn
@@ -5525,6 +5525,8 @@ "media/platform_verification_chromeos.h", "memory/oom_kills_monitor.cc", "memory/oom_kills_monitor.h", + "notifications/passphrase_textfield.cc", + "notifications/passphrase_textfield.h", "obsolete_system/obsolete_system_stub.cc", "platform_keys/extension_key_permissions_service.cc", "platform_keys/extension_key_permissions_service.h", @@ -6668,6 +6670,8 @@ "media/webrtc/tab_capture_access_handler.h", "metrics/extensions_metrics_provider.cc", "metrics/extensions_metrics_provider.h", + "performance_manager/extension_watcher.cc", + "performance_manager/extension_watcher.h", "policy/chrome_extension_policy_migrator.cc", "policy/chrome_extension_policy_migrator.h", "renderer_context_menu/context_menu_content_type_app_mode.cc", @@ -7712,6 +7716,10 @@ if (is_android) { deps += [ "//chrome/browser/ui/webui/feed_internals:mojo_bindings_js" ] } + + if (is_android || is_linux || is_chromeos || is_win) { + deps += [ "//chrome/browser/resources/sandbox_internals:build_ts" ] + } } action("expired_flags_list_gen") {
diff --git a/chrome/browser/about_flags.cc b/chrome/browser/about_flags.cc index a006cfa..6466f81 100644 --- a/chrome/browser/about_flags.cc +++ b/chrome/browser/about_flags.cc
@@ -3483,6 +3483,10 @@ {"adaptive-charging", flag_descriptions::kAdaptiveChargingName, flag_descriptions::kAdaptiveChargingDescription, kOsCrOS, FEATURE_VALUE_TYPE(ash::features::kAdaptiveCharging)}, + {"adaptive-charging-for-testing", + flag_descriptions::kAdaptiveChargingForTestingName, + flag_descriptions::kAdaptiveChargingForTestingDescription, kOsCrOS, + FEATURE_VALUE_TYPE(ash::features::kAdaptiveChargingForTesting)}, {"allow-poly-device-pairing", flag_descriptions::kAllowPolyDevicePairingName, flag_descriptions::kAllowPolyDevicePairingDescription, kOsCrOS, @@ -7147,6 +7151,10 @@ {"launcher-game-search", flag_descriptions::kLauncherGameSearchName, flag_descriptions::kLauncherGameSearchDescription, kOsCrOS, FEATURE_VALUE_TYPE(search_features::kLauncherGameSearch)}, + {"launcher-hide-continue-section", + flag_descriptions::kLauncherHideContinueSectionName, + flag_descriptions::kLauncherHideContinueSectionDescription, kOsCrOS, + FEATURE_VALUE_TYPE(ash::features::kLauncherHideContinueSection)}, {"launcher-nudge", flag_descriptions::kLauncherNudgeName, flag_descriptions::kLauncherNudgeDescription, kOsCrOS, FEATURE_VALUE_TYPE(ash::features::kShelfLauncherNudge)}, @@ -7337,10 +7345,6 @@ #endif #if BUILDFLAG(IS_ANDROID) - {"android-detailed-language-settings", - flag_descriptions::kAndroidDetailedLanguageSettingsName, - flag_descriptions::kAndroidDetailedLanguageSettingsDescription, kOsAndroid, - FEATURE_VALUE_TYPE(language::kDetailedLanguageSettings)}, {"android-force-app-language-prompt", flag_descriptions::kAndroidForceAppLanguagePromptName, flag_descriptions::kAndroidForceAppLanguagePromptDescription, kOsAndroid,
diff --git a/chrome/browser/ash/accessibility/spoken_feedback_browsertest.cc b/chrome/browser/ash/accessibility/spoken_feedback_browsertest.cc index 623d048..a1e68e52 100644 --- a/chrome/browser/ash/accessibility/spoken_feedback_browsertest.cc +++ b/chrome/browser/ash/accessibility/spoken_feedback_browsertest.cc
@@ -1740,8 +1740,10 @@ // has the same name as the desk it was created from, in this case the default // desk name is "Desk 1". sm_.Call([this]() { SendKeyPress(ui::VKEY_TAB); }); - sm_.ExpectSpeechPattern("Desk 1"); + sm_.ExpectSpeechPattern("Template, Desk 1"); sm_.ExpectSpeech("Button"); + sm_.ExpectSpeech("Press Ctrl plus W to close"); + sm_.ExpectSpeech("Press Search plus Space to activate"); // The next item is the textfield inside the template card, which also has the // same name as the desk it was created from.
diff --git a/chrome/browser/ash/crosapi/browser_util_unittest.cc b/chrome/browser/ash/crosapi/browser_util_unittest.cc index defae9c..604155b 100644 --- a/chrome/browser/ash/crosapi/browser_util_unittest.cc +++ b/chrome/browser/ash/crosapi/browser_util_unittest.cc
@@ -20,7 +20,6 @@ #include "chrome/browser/ash/crosapi/idle_service_ash.h" #include "chrome/browser/ash/login/users/fake_chrome_user_manager.h" #include "chrome/browser/ash/profiles/profile_helper.h" -#include "chrome/browser/policy/profile_policy_connector.h" #include "chrome/test/base/scoped_testing_local_state.h" #include "chrome/test/base/testing_browser_process.h" #include "chrome/test/base/testing_profile.h" @@ -219,8 +218,7 @@ base::test::ScopedFeatureList feature_list; feature_list.InitAndEnableFeature(chromeos::features::kLacrosSupport); AddRegularUser("user@managedchrome.com"); - testing_profile_.GetProfilePolicyConnector()->OverrideIsManagedForTesting( - true); + { ScopedLacrosAvailabilityCache cache(LacrosAvailability::kLacrosDisallowed); EXPECT_FALSE(browser_util::IsLacrosEnabled()); @@ -253,8 +251,6 @@ TEST_F(LacrosSupportBrowserUtilTest, AshWebBrowserEnabled) { base::test::ScopedFeatureList feature_list; AddRegularUser("user@managedchrome.com"); - testing_profile_.GetProfilePolicyConnector()->OverrideIsManagedForTesting( - true); // Lacros is not allowed. { @@ -304,8 +300,6 @@ TEST_F(BrowserUtilTest, IsAshWebBrowserDisabled) { base::test::ScopedFeatureList feature_list; AddRegularUser("user@managedchrome.com"); - testing_profile_.GetProfilePolicyConnector()->OverrideIsManagedForTesting( - true); ScopedLacrosAvailabilityCache cache(LacrosAvailability::kLacrosOnly); // Lacros is allowed and enabled and is the only browser by policy. @@ -388,8 +382,6 @@ base::test::ScopedFeatureList feature_list; feature_list.InitAndEnableFeature(chromeos::features::kLacrosSupport); AddRegularUser("user@managedchrome.com"); - testing_profile_.GetProfilePolicyConnector()->OverrideIsManagedForTesting( - true); { ScopedLacrosAvailabilityCache cache(LacrosAvailability::kLacrosDisallowed);
diff --git a/chrome/browser/ash/crostini/crostini_terminal.cc b/chrome/browser/ash/crostini/crostini_terminal.cc index 2b00800..4c67ae57 100644 --- a/chrome/browser/ash/crostini/crostini_terminal.cc +++ b/chrome/browser/ash/crostini/crostini_terminal.cc
@@ -95,45 +95,51 @@ apps::AppLaunchParams params) { // This function is called asynchronously, so we need to check whether // `profile` is still valid first. - if (g_browser_process) { - auto* profile_manager = g_browser_process->profile_manager(); - if (profile_manager && profile_manager->IsValidProfile(profile)) { - // This LaunchSystemWebAppImpl call is necessary. Terminal App uses its - // own CrostiniApps publisher for launching. Calling - // LaunchSystemWebAppAsync would ask AppService to launch the App, which - // routes the launch request to this function, resulting in a loop. - // - // System Web Apps managed by Web App publisher should call - // LaunchSystemWebAppAsync. - - // Launch without a pinned home tab (settings page). - if (params.disposition == WindowOpenDisposition::NEW_POPUP) { - web_app::LaunchSystemWebAppImpl( - profile, ash::SystemWebAppType::TERMINAL, url, params); - return; - } - - // TODO(crbug.com/1308961): Migrate to use PWA pinned home tab when ready. - // If opening a new tab, first pin home tab. - full_restore::FullRestoreSaveHandler::GetInstance(); - GURL home(base::StrCat( - {chrome::kChromeUIUntrustedTerminalURL, kTerminalHomePath})); - Browser* browser = web_app::LaunchSystemWebAppImpl( - profile, ash::SystemWebAppType::TERMINAL, home, params); - if (url != home) { - chrome::AddTabAt(browser, url, /*index=*/1, /*foreground=*/true); - } - auto info = std::make_unique<app_restore::AppLaunchInfo>( - kCrostiniTerminalSystemAppId, browser->session_id().id(), - params.container, params.disposition, params.display_id, - std::vector<base::FilePath>{}, nullptr); - full_restore::SaveAppLaunchInfo(browser->profile()->GetPath(), - std::move(info)); - - return; - } + if (!g_browser_process) { + LOG(WARNING) << "Abort launching terminal, invalid browser process."; + return; } - LOG(WARNING) << "Profile becomes invalid. 
Abort launching terminal."; + + auto* profile_manager = g_browser_process->profile_manager(); + if (!profile_manager || !profile_manager->IsValidProfile(profile)) { + LOG(WARNING) << "Abort launching terminal, invalid profile."; + return; + } + + // This LaunchSystemWebAppImpl call is necessary. Terminal App uses its + // own CrostiniApps publisher for launching. Calling + // LaunchSystemWebAppAsync would ask AppService to launch the App, which + // routes the launch request to this function, resulting in a loop. + // + // System Web Apps managed by Web App publisher should call + // LaunchSystemWebAppAsync. + + // Launch without a pinned home tab (settings page). + if (params.disposition == WindowOpenDisposition::NEW_POPUP) { + web_app::LaunchSystemWebAppImpl(profile, ash::SystemWebAppType::TERMINAL, + url, params); + return; + } + + // TODO(crbug.com/1308961): Migrate to use PWA pinned home tab when ready. + // If opening a new tab, first pin home tab. + full_restore::FullRestoreSaveHandler::GetInstance(); + GURL home( + base::StrCat({chrome::kChromeUIUntrustedTerminalURL, kTerminalHomePath})); + Browser* browser = web_app::LaunchSystemWebAppImpl( + profile, ash::SystemWebAppType::TERMINAL, home, params); + if (!browser) { + return; + } + if (url != home) { + chrome::AddTabAt(browser, url, /*index=*/1, /*foreground=*/true); + } + auto info = std::make_unique<app_restore::AppLaunchInfo>( + kCrostiniTerminalSystemAppId, browser->session_id().id(), + params.container, params.disposition, params.display_id, + std::vector<base::FilePath>{}, nullptr); + full_restore::SaveAppLaunchInfo(browser->profile()->GetPath(), + std::move(info)); } } // namespace
diff --git a/chrome/browser/ash/input_method/assistive_suggester.cc b/chrome/browser/ash/input_method/assistive_suggester.cc index a4ba07f..43455f4 100644 --- a/chrome/browser/ash/input_method/assistive_suggester.cc +++ b/chrome/browser/ash/input_method/assistive_suggester.cc
@@ -8,8 +8,10 @@ #include "ash/constants/ash_pref_names.h" #include "ash/public/cpp/window_properties.h" #include "ash/services/ime/public/cpp/suggestions.h" +#include "base/containers/fixed_flat_set.h" #include "base/feature_list.h" #include "base/hash/hash.h" +#include "base/location.h" #include "base/metrics/histogram_functions.h" #include "base/metrics/user_metrics.h" #include "base/strings/string_util.h" @@ -40,6 +42,16 @@ const char kMaxTextBeforeCursorLength = 50; +constexpr base::TimeDelta kLongpressActivationDelay = base::Seconds(1); + +// TODO(b/217560706): Make this different based on current engine after research +// is conducted. +constexpr auto kDefaultLongpressEnabledKeys = base::MakeFixedFlatSet<char>( + {'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', + 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'}); + void RecordAssistiveMatch(AssistiveType type) { base::UmaHistogramEnumeration("InputMethod.Assistive.Match", type); @@ -177,6 +189,7 @@ personal_data_manager_for_testing), emoji_suggester_(suggestion_handler, profile), multi_word_suggester_(suggestion_handler, profile), + longpress_diacritics_suggester_(suggestion_handler), suggester_switch_(std::move(suggester_switch)) { RecordAssistiveUserPrefForPersonalInfo( profile_->GetPrefs()->GetBoolean(prefs::kAssistPersonalInfoEnabled)); @@ -188,7 +201,9 @@ bool AssistiveSuggester::IsAssistiveFeatureEnabled() { return IsAssistPersonalInfoEnabled() || IsEmojiSuggestAdditionEnabled() || - IsMultiWordSuggestEnabled() || IsEnhancedEmojiSuggestEnabled(); + IsMultiWordSuggestEnabled() || IsEnhancedEmojiSuggestEnabled() || + base::FeatureList::IsEnabled( + features::kDiacriticsOnPhysicalKeyboardLongpress); } void AssistiveSuggester::FetchEnabledSuggestionsFromBrowserContextThen( @@ -337,6 +352,7 @@ personal_info_suggester_.OnFocus(context_id); 
emoji_suggester_.OnFocus(context_id); multi_word_suggester_.OnFocus(context_id); + longpress_diacritics_suggester_.OnFocus(context_id); suggester_switch_->FetchEnabledSuggestionsThen( base::BindOnce(&AssistiveSuggester::RecordTextInputStateMetrics, weak_ptr_factory_.GetWeakPtr())); @@ -347,6 +363,7 @@ personal_info_suggester_.OnBlur(); emoji_suggester_.OnBlur(); multi_word_suggester_.OnBlur(); + longpress_diacritics_suggester_.OnBlur(); } bool AssistiveSuggester::OnKeyEvent(const ui::KeyEvent& event) { @@ -375,7 +392,56 @@ } } + // Longpress diacritics behaviour overrides the longpress to repeat key + // behaviour for alphabetical keys. + if (base::FeatureList::IsEnabled( + features::kDiacriticsOnPhysicalKeyboardLongpress) && + event.is_repeat() && + kDefaultLongpressEnabledKeys.contains(event.GetCharacter())) { + return true; // Do not propagate this event. + } + + HandleLongpressEnabledKeyEvent(event); + return false; +}; // namespace input_method + +void AssistiveSuggester::HandleLongpressEnabledKeyEvent( + const ui::KeyEvent& event) { + if (const char c = event.GetCharacter(); + kDefaultLongpressEnabledKeys.contains(c) && + base::FeatureList::IsEnabled( + features::kDiacriticsOnPhysicalKeyboardLongpress)) { + // Process longpress keydown event. 
+ if (current_longpress_char_ == absl::nullopt && + event.type() == ui::EventType::ET_KEY_PRESSED) { + current_longpress_char_ = c; + longpress_timer_.Start( + FROM_HERE, kLongpressActivationDelay, + base::BindOnce(&AssistiveSuggester::OnLongpressDetected, + weak_ptr_factory_.GetWeakPtr())); + return; + } + + // Process longpress interrupted event (key press up before timer callback + // fired) + if (current_longpress_char_.has_value() && + event.type() == ui::EventType::ET_KEY_RELEASED && + *current_longpress_char_ == c) { + current_longpress_char_ = absl::nullopt; + longpress_timer_.Stop(); + return; + } + } +} + +void AssistiveSuggester::OnLongpressDetected() { + if (!current_longpress_char_.has_value()) { + return; + } + longpress_diacritics_suggester_.TrySuggestOnLongpress( + *current_longpress_char_); + current_longpress_char_ = absl::nullopt; } void AssistiveSuggester::OnExternalSuggestionsUpdated( @@ -493,7 +559,8 @@ return; if (IsMultiWordSuggestEnabled()) { - // Only multi word cares about tracking the current state of the text field + // Only multi word cares about tracking the current state of the text + // field multi_word_suggester_.OnSurroundingTextChanged(text, cursor_pos, anchor_pos); }
diff --git a/chrome/browser/ash/input_method/assistive_suggester.h b/chrome/browser/ash/input_method/assistive_suggester.h index 14c0fe5..d9cdd39b 100644 --- a/chrome/browser/ash/input_method/assistive_suggester.h +++ b/chrome/browser/ash/input_method/assistive_suggester.h
@@ -11,8 +11,10 @@ #include "ash/services/ime/public/cpp/suggestions.h" #include "base/memory/weak_ptr.h" +#include "base/timer/timer.h" #include "chrome/browser/ash/input_method/assistive_suggester_switch.h" #include "chrome/browser/ash/input_method/emoji_suggester.h" +#include "chrome/browser/ash/input_method/longpress_diacritics_suggester.h" #include "chrome/browser/ash/input_method/multi_word_suggester.h" #include "chrome/browser/ash/input_method/personal_info_suggester.h" #include "chrome/browser/ash/input_method/suggester.h" @@ -74,7 +76,7 @@ int anchor_pos); // Called when the user pressed a key. - // Returns true if suggester handles the event and it should stop propagate. + // Returns true if it should stop further processing of event. bool OnKeyEvent(const ui::KeyEvent& event); // Called when suggestions are generated outside of the assistive framework. @@ -163,10 +165,15 @@ void RecordTextInputStateMetrics( const AssistiveSuggesterSwitch::EnabledSuggestions& enabled_suggestions); + void HandleLongpressEnabledKeyEvent(const ui::KeyEvent& key_character); + + void OnLongpressDetected(); + Profile* profile_; PersonalInfoSuggester personal_info_suggester_; EmojiSuggester emoji_suggester_; MultiWordSuggester multi_word_suggester_; + LongpressDiacriticsSuggester longpress_diacritics_suggester_; std::unique_ptr<AssistiveSuggesterSwitch> suggester_switch_; // The id of the currently active input engine. @@ -175,6 +182,13 @@ // ID of the focused text field, nullopt if none focused. absl::optional<int> focused_context_id_; + // Char of the currently held down key. nullopt if no longpress in progress. + absl::optional<char> current_longpress_char_; + + // Timer for longpress. Starts when key is held down. Fires when successfully + // held down for a specified longpress duration. + base::OneShotTimer longpress_timer_; + // The current suggester in use, nullptr means no suggestion is shown. Suggester* current_suggester_ = nullptr;
diff --git a/chrome/browser/ash/input_method/assistive_suggester_unittest.cc b/chrome/browser/ash/input_method/assistive_suggester_unittest.cc index d61001d..ce3f2f79 100644 --- a/chrome/browser/ash/input_method/assistive_suggester_unittest.cc +++ b/chrome/browser/ash/input_method/assistive_suggester_unittest.cc
@@ -9,6 +9,7 @@ #include "base/strings/utf_string_conversions.h" #include "base/test/metrics/histogram_tester.h" #include "base/test/scoped_feature_list.h" +#include "base/time/time.h" #include "chrome/browser/ash/input_method/assistive_suggester_client_filter.h" #include "chrome/browser/ash/input_method/assistive_suggester_switch.h" #include "chrome/browser/ash/input_method/fake_suggestion_handler.h" @@ -49,6 +50,10 @@ ui::DomKey::NONE, ui::EventTimeForNow()); } +ui::KeyEvent ReleaseKey(const ui::DomCode& code) { + return GenerateKeyEvent(code, ui::EventType::ET_KEY_RELEASED, ui::EF_NONE); +} + ui::KeyEvent PressKey(const ui::DomCode& code) { return GenerateKeyEvent(code, ui::EventType::ET_KEY_PRESSED, ui::EF_NONE); } @@ -67,6 +72,11 @@ ui::EF_SHIFT_DOWN); } +ui::KeyEvent CreateRepeatKeyEvent(const ui::DomCode& code) { + return GenerateKeyEvent(code, ui::EventType::ET_KEY_PRESSED, + ui::EF_IS_REPEAT); +} + void SetInputMethodOptions(Profile& profile, bool predictive_writing_enabled) { base::Value input_method_setting(base::Value::Type::DICTIONARY); input_method_setting.SetPath(std::string(kUsEnglishEngineId) + @@ -114,7 +124,8 @@ profile_->GetPrefs()->SetBoolean(prefs::kEmojiSuggestionEnabled, false); } - content::BrowserTaskEnvironment task_environment_; + content::BrowserTaskEnvironment task_environment_{ + base::test::TaskEnvironment::TimeSource::MOCK_TIME}; std::unique_ptr<TestingProfile> profile_; std::unique_ptr<AssistiveSuggester> assistive_suggester_; std::unique_ptr<FakeSuggestionHandler> suggestion_handler_; @@ -265,6 +276,15 @@ EXPECT_FALSE(assistive_suggester_->IsAssistiveFeatureEnabled()); } +TEST_F(AssistiveSuggesterTest, + AssistiveDiacriticsLongpressFlagEnabled_AssistiveFeatureEnabled) { + base::test::ScopedFeatureList feature_list; + feature_list.InitAndEnableFeature( + features::kDiacriticsOnPhysicalKeyboardLongpress); + + EXPECT_TRUE(assistive_suggester_->IsAssistiveFeatureEnabled()); +} + TEST_F(AssistiveSuggesterTest, 
RecordPredictiveWritingPrefOnActivate) { base::test::ScopedFeatureList feature_list; feature_list.InitWithFeatures( @@ -387,6 +407,101 @@ AssistiveTextInputState::kFeatureEnabled, 1); } +TEST_F(AssistiveSuggesterTest, DiacriticsSugestionOnKeyDownLongpress) { + base::test::ScopedFeatureList feature_list; + feature_list.InitWithFeatures( + /*enabled_features=*/{features::kDiacriticsOnPhysicalKeyboardLongpress}, + /*disabled_features=*/{}); + assistive_suggester_->OnActivate(kUsEnglishEngineId); + assistive_suggester_->OnFocus(5); + + EXPECT_FALSE(assistive_suggester_->OnKeyEvent(PressKey(ui::DomCode::US_A))); + task_environment_.FastForwardBy(base::Seconds(1)); + + EXPECT_TRUE(suggestion_handler_->GetShowingSuggestion()); + EXPECT_EQ(suggestion_handler_->GetSuggestionText(), u"à;á;â;ã;ã;ä;å;ā"); +} + +TEST_F(AssistiveSuggesterTest, + DiacriticsSugestionOnKeyDownLongpressNotInterruptedByOtherKeys) { + base::test::ScopedFeatureList feature_list; + feature_list.InitWithFeatures( + /*enabled_features=*/{features::kDiacriticsOnPhysicalKeyboardLongpress}, + /*disabled_features=*/{}); + assistive_suggester_->OnActivate(kUsEnglishEngineId); + assistive_suggester_->OnFocus(5); + + EXPECT_FALSE(assistive_suggester_->OnKeyEvent(PressKey(ui::DomCode::US_A))); + EXPECT_FALSE(assistive_suggester_->OnKeyEvent(PressKey(ui::DomCode::US_O))); + EXPECT_FALSE(assistive_suggester_->OnKeyEvent(ReleaseKey(ui::DomCode::US_O))); + task_environment_.FastForwardBy(base::Seconds(1)); + EXPECT_TRUE(suggestion_handler_->GetShowingSuggestion()); + EXPECT_EQ(suggestion_handler_->GetSuggestionText(), u"à;á;â;ã;ã;ä;å;ā"); +} + +TEST_F(AssistiveSuggesterTest, + DiacriticsSugestionWithoutContextIgnoresOnKeyDownLongpress) { + base::test::ScopedFeatureList feature_list; + feature_list.InitWithFeatures( + /*enabled_features=*/{features::kDiacriticsOnPhysicalKeyboardLongpress}, + /*disabled_features=*/{}); + assistive_suggester_->OnActivate(kUsEnglishEngineId); + + 
EXPECT_FALSE(assistive_suggester_->OnKeyEvent(PressKey(ui::DomCode::US_A))); + task_environment_.FastForwardBy(base::Seconds(1)); + EXPECT_FALSE(suggestion_handler_->GetShowingSuggestion()); + EXPECT_EQ(suggestion_handler_->GetSuggestionText(), u""); +} + +TEST_F(AssistiveSuggesterTest, DiacriticsSugestionInterruptedDoesNotSuggest) { + base::test::ScopedFeatureList feature_list; + feature_list.InitWithFeatures( + /*enabled_features=*/{features::kDiacriticsOnPhysicalKeyboardLongpress}, + /*disabled_features=*/{}); + assistive_suggester_->OnActivate(kUsEnglishEngineId); + assistive_suggester_->OnFocus(5); + + EXPECT_FALSE(assistive_suggester_->OnKeyEvent(PressKey(ui::DomCode::US_A))); + task_environment_.FastForwardBy( + base::Milliseconds(100)); // Not long enough to trigger longpress. + EXPECT_FALSE(assistive_suggester_->OnKeyEvent(ReleaseKey(ui::DomCode::US_A))); + EXPECT_FALSE(suggestion_handler_->GetShowingSuggestion()); + EXPECT_EQ(suggestion_handler_->GetSuggestionText(), u""); +} + +TEST_F(AssistiveSuggesterTest, + ProcesssAndDoNotPropagateAlphaRepeatKeyIfDiacriticsOnLongpressEnabled) { + base::test::ScopedFeatureList feature_list; + feature_list.InitWithFeatures( + /*enabled_features=*/{features::kDiacriticsOnPhysicalKeyboardLongpress}, + /*disabled_features=*/{}); + assistive_suggester_->OnActivate(kUsEnglishEngineId); + assistive_suggester_->OnFocus(5); + + // Returning true tells IME to not propagate this event. + EXPECT_TRUE(assistive_suggester_->OnKeyEvent( + CreateRepeatKeyEvent(ui::DomCode::US_A))); + task_environment_.FastForwardBy( + base::Seconds(1)); // Long enough to trigger longpress. 
+ EXPECT_FALSE(assistive_suggester_->OnKeyEvent(ReleaseKey(ui::DomCode::US_A))); + EXPECT_FALSE(suggestion_handler_->GetShowingSuggestion()); + EXPECT_EQ(suggestion_handler_->GetSuggestionText(), u""); +} + +TEST_F(AssistiveSuggesterTest, + IgnoreAndPropagateNonAlphaRepeatKeyIfDiacriticsOnLongpressEnabled) { + base::test::ScopedFeatureList feature_list; + feature_list.InitWithFeatures( + /*enabled_features=*/{features::kDiacriticsOnPhysicalKeyboardLongpress}, + /*disabled_features=*/{}); + assistive_suggester_->OnActivate(kUsEnglishEngineId); + assistive_suggester_->OnFocus(5); + + // Returning false tells IME to propagate this event. + EXPECT_FALSE(assistive_suggester_->OnKeyEvent( + CreateRepeatKeyEvent(ui::DomCode::ARROW_DOWN))); +} + struct PersonalInfoTestCase { std::string test_name; std::u16string surrounding_text;
diff --git a/chrome/browser/ash/notifications/request_pin_view.cc b/chrome/browser/ash/notifications/request_pin_view.cc index d186643..c2c55f98 100644 --- a/chrome/browser/ash/notifications/request_pin_view.cc +++ b/chrome/browser/ash/notifications/request_pin_view.cc
@@ -12,7 +12,7 @@ #include "base/i18n/number_formatting.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" -#include "chrome/browser/ash/notifications/passphrase_textfield.h" +#include "chrome/browser/notifications/passphrase_textfield.h" #include "chrome/browser/ui/views/chrome_layout_provider.h" #include "chrome/browser/ui/views/chrome_typography.h" #include "chrome/grit/generated_resources.h" @@ -143,6 +143,10 @@ UpdateHeaderText(); } +bool RequestPinView::IsTextStyleOfErrorLabelCorrectForTesting() const { + return STYLE_RED == error_label_->GetTextStyle(); +} + void RequestPinView::UpdateHeaderText() { int label_text_id = IDS_REQUEST_PIN_DIALOG_HEADER; std::u16string label_text = @@ -173,7 +177,7 @@ views::LayoutAlignment::kStart); // Textfield to enter the PIN/PUK. - textfield_ = AddChildView(std::make_unique<ash::PassphraseTextfield>()); + textfield_ = AddChildView(std::make_unique<PassphraseTextfield>()); textfield_->set_controller(this); textfield_->SetEnabled(true); textfield_->SetAssociatedLabel(header_label_);
diff --git a/chrome/browser/ash/notifications/request_pin_view.h b/chrome/browser/ash/notifications/request_pin_view.h index 06bf038a..15aba1d 100644 --- a/chrome/browser/ash/notifications/request_pin_view.h +++ b/chrome/browser/ash/notifications/request_pin_view.h
@@ -82,8 +82,10 @@ // the header text displayed by the view. void SetExtensionName(const std::string& extension_name); + // Checking that the text style of `error_label_` is correct. + bool IsTextStyleOfErrorLabelCorrectForTesting() const; + views::Textfield* textfield_for_testing() { return textfield_; } - views::Label* error_label_for_testing() { return error_label_; } private: // This initializes the view, with all the UI components.
diff --git a/chrome/browser/ash/notifications/request_system_proxy_credentials_view.cc b/chrome/browser/ash/notifications/request_system_proxy_credentials_view.cc index e7448c0..0b88958 100644 --- a/chrome/browser/ash/notifications/request_system_proxy_credentials_view.cc +++ b/chrome/browser/ash/notifications/request_system_proxy_credentials_view.cc
@@ -12,7 +12,7 @@ #include "base/i18n/number_formatting.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" -#include "chrome/browser/ash/notifications/passphrase_textfield.h" +#include "chrome/browser/notifications/passphrase_textfield.h" #include "chrome/browser/ui/browser_dialogs.h" #include "chrome/browser/ui/views/chrome_layout_provider.h" #include "chrome/grit/generated_resources.h" @@ -170,8 +170,8 @@ std::make_unique<views::Label>(l10n_util::GetStringUTF16( IDS_SYSTEM_PROXY_AUTH_DIALOG_PASSWORD_LABEL))); password_label->SetEnabled(true); - password_textfield_ = - auth_container->AddChildView(std::make_unique<PassphraseTextfield>()); + password_textfield_ = auth_container->AddChildView( + std::make_unique<chromeos::PassphraseTextfield>()); password_textfield_->SetEnabled(true); password_textfield_->SetAssociatedLabel(password_label); auth_container->AddPaddingRow(views::TableLayout::kFixedSize,
diff --git a/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl.cc b/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl.cc index 0b7462e..49ee8793 100644 --- a/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl.cc +++ b/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl.cc
@@ -35,7 +35,9 @@ PersonalizationAppKeyboardBacklightProviderImpl:: PersonalizationAppKeyboardBacklightProviderImpl(content::WebUI* web_ui) - : profile_(Profile::FromWebUI(web_ui)) {} + : profile_(Profile::FromWebUI(web_ui)) { + wallpaper_controller_observation_.Observe(WallpaperController::Get()); +} PersonalizationAppKeyboardBacklightProviderImpl:: ~PersonalizationAppKeyboardBacklightProviderImpl() = default; @@ -58,6 +60,9 @@ // Call it once to get the status of color preset. NotifyBacklightColorChanged(); + + // Call it once to get the wallpaper extracted color. + OnWallpaperColorsChanged(); } void PersonalizationAppKeyboardBacklightProviderImpl::SetBacklightColor( @@ -77,6 +82,14 @@ } void PersonalizationAppKeyboardBacklightProviderImpl:: + OnWallpaperColorsChanged() { + DCHECK(keyboard_backlight_observer_remote_.is_bound()); + keyboard_backlight_observer_remote_->OnWallpaperColorChanged( + ConvertBacklightColorToSkColor( + personalization_app::mojom::BacklightColor::kWallpaper)); +} + +void PersonalizationAppKeyboardBacklightProviderImpl:: NotifyBacklightColorChanged() { DCHECK(keyboard_backlight_observer_remote_.is_bound()); keyboard_backlight_observer_remote_->OnBacklightColorChanged(
diff --git a/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl.h b/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl.h index 34e82e9..cd520a4 100644 --- a/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl.h +++ b/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl.h
@@ -5,8 +5,11 @@ #ifndef CHROME_BROWSER_ASH_WEB_APPLICATIONS_PERSONALIZATION_APP_PERSONALIZATION_APP_KEYBOARD_BACKLIGHT_PROVIDER_IMPL_H_ #define CHROME_BROWSER_ASH_WEB_APPLICATIONS_PERSONALIZATION_APP_PERSONALIZATION_APP_KEYBOARD_BACKLIGHT_PROVIDER_IMPL_H_ +#include "ash/public/cpp/wallpaper/wallpaper_controller.h" +#include "ash/public/cpp/wallpaper/wallpaper_controller_observer.h" #include "ash/webui/personalization_app/mojom/personalization_app.mojom.h" #include "ash/webui/personalization_app/personalization_app_keyboard_backlight_provider.h" +#include "base/scoped_observation.h" #include "mojo/public/cpp/bindings/receiver.h" #include "mojo/public/cpp/bindings/remote.h" @@ -19,7 +22,8 @@ namespace ash::personalization_app { class PersonalizationAppKeyboardBacklightProviderImpl - : public PersonalizationAppKeyboardBacklightProvider { + : public PersonalizationAppKeyboardBacklightProvider, + public WallpaperControllerObserver { public: explicit PersonalizationAppKeyboardBacklightProviderImpl( content::WebUI* web_ui); @@ -45,6 +49,9 @@ void SetBacklightColor( ash::personalization_app::mojom::BacklightColor backlight_color) override; + // WallpaperControllerObserver: + void OnWallpaperColorsChanged() override; + private: // Notify webUI the current state of backlight color. void NotifyBacklightColorChanged(); @@ -57,6 +64,9 @@ mojo::Remote<ash::personalization_app::mojom::KeyboardBacklightObserver> keyboard_backlight_observer_remote_; + + base::ScopedObservation<WallpaperController, WallpaperControllerObserver> + wallpaper_controller_observation_{this}; }; } // namespace ash::personalization_app
diff --git a/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl_unittest.cc b/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl_unittest.cc index abd99ac2..0b660273 100644 --- a/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl_unittest.cc +++ b/chrome/browser/ash/web_applications/personalization_app/personalization_app_keyboard_backlight_provider_impl_unittest.cc
@@ -15,6 +15,7 @@ #include "content/public/browser/web_contents.h" #include "content/public/test/test_web_ui.h" #include "testing/gtest/include/gtest/gtest.h" +#include "third_party/skia/include/core/SkColor.h" namespace ash { namespace personalization_app { @@ -30,6 +31,10 @@ backlight_color_ = backlight_color; } + void OnWallpaperColorChanged(SkColor wallpaper_color) override { + wallpaper_color_ = wallpaper_color; + } + mojo::PendingRemote< ash::personalization_app::mojom::KeyboardBacklightObserver> pending_remote() { @@ -45,11 +50,17 @@ return backlight_color_; } + SkColor wallpaper_color() { + keyboard_backlight_observer_receiver_.FlushForTesting(); + return wallpaper_color_; + } + private: mojo::Receiver<ash::personalization_app::mojom::KeyboardBacklightObserver> keyboard_backlight_observer_receiver_{this}; mojom::BacklightColor backlight_color_ = mojom::BacklightColor::kWallpaper; + SkColor wallpaper_color_ = SK_ColorTRANSPARENT; }; } // namespace @@ -115,6 +126,11 @@ return test_keyboard_backlight_observer_.backlight_color(); } + SkColor ObservedWallpaperColor() { + keyboard_backlight_provider_remote_.FlushForTesting(); + return test_keyboard_backlight_observer_.wallpaper_color(); + } + private: base::test::ScopedFeatureList scoped_feature_list_; user_manager::ScopedUserManager scoped_user_manager_; @@ -140,5 +156,15 @@ EXPECT_EQ(mojom::BacklightColor::kBlue, ObservedBacklightColor()); } +TEST_F(PersonalizationAppKeyboardBacklightProviderImplTest, + ObserveWallpaperColor) { + SetKeyboardBacklightObserver(); + keyboard_backlight_provider_remote()->FlushForTesting(); + keyboard_backlight_provider()->OnWallpaperColorsChanged(); + + // Verify JS side is notified. + EXPECT_EQ(SK_ColorTRANSPARENT, ObservedWallpaperColor()); +} + } // namespace personalization_app } // namespace ash
diff --git a/chrome/browser/chromeos/BUILD.gn b/chrome/browser/chromeos/BUILD.gn index 0ef4609b..e273165 100644 --- a/chrome/browser/chromeos/BUILD.gn +++ b/chrome/browser/chromeos/BUILD.gn
@@ -1618,8 +1618,6 @@ "../ash/notifications/kiosk_external_update_notification.h", "../ash/notifications/low_disk_notification.cc", "../ash/notifications/low_disk_notification.h", - "../ash/notifications/passphrase_textfield.cc", - "../ash/notifications/passphrase_textfield.h", "../ash/notifications/request_pin_view.cc", "../ash/notifications/request_pin_view.h", "../ash/notifications/request_system_proxy_credentials_view.cc",
diff --git a/chrome/browser/client_hints/client_hints_browsertest.cc b/chrome/browser/client_hints/client_hints_browsertest.cc index 21749eb..b8193763 100644 --- a/chrome/browser/client_hints/client_hints_browsertest.cc +++ b/chrome/browser/client_hints/client_hints_browsertest.cc
@@ -56,7 +56,6 @@ #include "content/public/browser/navigation_entry.h" #include "content/public/browser/navigation_handle.h" #include "content/public/browser/render_view_host.h" -#include "content/public/browser/storage_partition.h" #include "content/public/browser/web_contents_observer.h" #include "content/public/common/content_features.h" #include "content/public/common/content_switches.h" @@ -67,7 +66,6 @@ #include "content/public/test/test_navigation_observer.h" #include "content/public/test/test_utils.h" #include "content/public/test/url_loader_interceptor.h" -#include "net/base/features.h" #include "net/dns/mock_host_resolver.h" #include "net/http/http_request_headers.h" #include "net/http/http_status_code.h" @@ -83,8 +81,6 @@ #include "services/network/public/cpp/features.h" #include "services/network/public/cpp/network_switches.h" #include "services/network/public/cpp/resource_request.h" -#include "services/network/public/mojom/cookie_manager.mojom.h" -#include "services/network/public/mojom/network_context.mojom.h" #include "services/network/public/mojom/web_client_hints_types.mojom-shared.h" #include "testing/gmock/include/gmock/gmock-matchers.h" #include "testing/gmock/include/gmock/gmock.h" @@ -1054,13 +1050,6 @@ continue; } - // Skip over the `Sec-CH-Partitioned-Cookies' client hint because it is - // only added in the presence of a valid "PartitionedCookies" Origin - // Trial token. - if (header == "sec-ch-partitioned-cookies") { - continue; - } - EXPECT_EQ(expect_client_hints, base::Contains(request.headers, header)); } } @@ -5180,830 +5169,6 @@ /*ch_ua_exist_expected=*/true); } -// TODO(crbug.com/1296161): Delete this when the partitioned cookies Origin -// Trial is over. -class PartitionedCookiesOriginTrialBrowserTest : public InProcessBrowserTest { - public: - void SetUpCommandLine(base::CommandLine* command_line) override { - // The public key for the default privatey key used by the - // tools/origin_trials/generate_token.py tool. 
- static constexpr char kOriginTrialTestPublicKey[] = - "dRCs+TocuKkocNKa0AtZ4awrt9XKH2SQCI6o4FY6BNA="; - command_line->AppendSwitchASCII(embedder_support::kOriginTrialPublicKey, - kOriginTrialTestPublicKey); - } - - void SetUp() override { - scoped_feature_list_.InitWithFeatureList(EnabledFeatures()); - InProcessBrowserTest::SetUp(); - } - - void TearDownOnMainThread() override { - url_loader_interceptor_.reset(); - InProcessBrowserTest::TearDownOnMainThread(); - } - - network::mojom::CookieManager* GetCookieManager() { - return browser() - ->profile() - ->GetDefaultStoragePartition() - ->GetCookieManagerForBrowserProcess(); - } - - void SetCookie(const std::string& name, - const std::string& value, - const GURL& url, - const absl::optional<net::CookiePartitionKey>& partition_key) { - auto cookie = net::CanonicalCookie::CreateUnsafeCookieForTesting( - name, value, url.host(), "/", base::Time::Now() - base::Days(1), - base::Time::Now() + base::Days(1), base::Time::Now(), base::Time::Now(), - /*secure=*/true, /*httponly=*/false, - net::CookieSameSite::NO_RESTRICTION, - net::CookiePriority::COOKIE_PRIORITY_DEFAULT, /*same_party=*/false, - partition_key); - EXPECT_TRUE(cookie->IsCanonical()); - - base::RunLoop run_loop; - GetCookieManager()->SetCanonicalCookie( - *cookie, url, net::CookieOptions::MakeAllInclusive(), - base::BindLambdaForTesting( - [&](net::CookieAccessResult set_cookie_result) { - EXPECT_TRUE(set_cookie_result.status.IsInclude()); - run_loop.Quit(); - })); - run_loop.Run(); - } - - std::vector<net::CanonicalCookie> GetCookies(const GURL& url) { - std::vector<net::CanonicalCookie> cookies; - - base::RunLoop run_loop; - GetCookieManager()->GetCookieList( - url, net::CookieOptions::MakeAllInclusive(), - net::CookiePartitionKeyCollection::ContainsAll(), - base::BindLambdaForTesting( - [&](const std::vector<::net::CookieWithAccessResult>& result, - const std::vector<::net::CookieWithAccessResult>& - excluded_cookies) { - 
EXPECT_TRUE(excluded_cookies.empty()); - for (const auto& el : result) { - cookies.push_back(el.cookie); - } - run_loop.Quit(); - })); - run_loop.Run(); - - return cookies; - } - - void SetTestOptions(const OriginTrialTestOptions& test_setting, - const std::set<GURL>& expected_request_urls) { - test_options_ = test_setting; - expected_request_urls_ = expected_request_urls; - } - - void NavigateTo(const GURL& url) { - ASSERT_TRUE(ui_test_utils::NavigateToURL(browser(), url)); - } - - void NavigateAndCheckClientHint(const GURL& url, - bool expects_hint_is_present, - bool expected_hint_value) { - NavigateTo(url); - auto header_value = GetLastPartitionedCookiesClientHintValue(); - if (expects_hint_is_present) { - EXPECT_THAT(header_value, - Optional(Eq(expected_hint_value ? "?1" : "?0"))); - } else { - EXPECT_THAT(header_value, Eq(absl::nullopt)); - } - } - - void NavigateTwiceAndCheckClientHint(const GURL& url, - bool expects_hint_is_present, - bool expected_hint_value) { - NavigateTo(url); - NavigateAndCheckClientHint(url, expects_hint_is_present, - expected_hint_value); - } - - absl::optional<std::string> GetLastPartitionedCookiesClientHintValue() { - std::string header_value; - if (url_loader_interceptor_->GetLastRequestHeaders().GetHeader( - "sec-ch-partitioned-cookies", &header_value)) { - return header_value; - } - return absl::nullopt; - } - - protected: - virtual std::string BuildOriginTrialHeader() const { return ""; } - - virtual std::unique_ptr<base::FeatureList> EnabledFeatures() { - std::unique_ptr<base::FeatureList> feature_list(new base::FeatureList); - feature_list->InitializeFromCommandLine( - "UserAgentClientHint,CriticalClientHint,AcceptCHFrame," - "PartitionedCookies", - ""); - return feature_list; - } - - OriginTrialTestOptions test_options_; - std::set<GURL> expected_request_urls_; - std::unique_ptr<URLLoaderInterceptor> url_loader_interceptor_; - base::test::ScopedFeatureList scoped_feature_list_; -}; - -// Tests that verify 
Sec-CH-Partitioned-Cookies client hint is sent if and only -// if the PartitionedCookies Origin Trial token is present and valid in the -// response headers. -// -// The test Origin Trial token was generated by running: -// python tools/origin_trials/generate_token.py https://127.0.0.1:44444 \ -// PartitionedCookies \ -// --expire-timestamp=2000000000 -// -// The Origin Trial token expires in 2033. Generate a new token by then, or -// find a better way to re-generate a test trial token. -class SameOriginPartitionedCookiesOriginTrialBrowserTest - : public PartitionedCookiesOriginTrialBrowserTest { - public: - SameOriginPartitionedCookiesOriginTrialBrowserTest() = default; - - void SetUpOnMainThread() override { - // We use a URLLoaderInterceptor, rather than the EmbeddedTestServer, since - // the origin trial token in the response is associated with a fixed - // origin, whereas EmbeddedTestServer serves content on a random port. - url_loader_interceptor_ = std::make_unique< - URLLoaderInterceptor>(base::BindRepeating( - &SameOriginPartitionedCookiesOriginTrialBrowserTest::InterceptRequest, - base::Unretained(this))); - InProcessBrowserTest::SetUpOnMainThread(); - } - - // The URL that was used to register the Origin Trial token. 
- static constexpr const char kOriginUrl[] = "https://127.0.0.1:44444"; - - GURL partitioned_cookies_url() const { - return GURL( - base::StrCat({kOriginUrl, "/partitioned_cookies_same_origin.html"})); - } - - std::string BuildOriginTrialHeader() const override { - std::string headers; - - static constexpr const char kOriginTrialToken[] = - "A4s/" - "iPKfhEfgqQIIuz4zLuCpONpXOuYyJFBhBx1MfgS1aNhFujyhsg4lkfTRfjzQCI3aUbMwtN" - "m25elLTR4UIgAAAABceyJvcmlnaW4iOiAiaHR0cHM6Ly8xMjcuMC4wLjE6NDQ0NDQiLCAi" - "ZmVhdHVyZSI6ICJQYXJ0aXRpb25lZENvb2tpZXMiLCAiZXhwaXJ5IjogMjAwMDAwMDAwMH" - "0="; - - if (test_options_.has_accept_ch_header) { - base::StrAppend(&headers, - {"Accept-CH: ", "sec-ch-partitioned-cookies", "\n"}); - } - if (test_options_.has_critical_ch_header) { - base::StrAppend(&headers, - {"Critical-CH: ", "sec-ch-partitioned-cookies", "\n"}); - } - if (test_options_.has_ot_token) { - base::StrAppend( - &headers, - {"Origin-Trial: ", - test_options_.valid_ot_token ? kOriginTrialToken : "invalid", "\n"}); - } - - return headers; - } - - // URLLoaderInterceptor callback - bool InterceptRequest(URLLoaderInterceptor::RequestParams* params) { - if (expected_request_urls_.find(params->url_request.url) == - expected_request_urls_.end()) - return false; - - std::string path = "chrome/test/data/client_hints"; - path.append(static_cast<std::string>(params->url_request.url.path_piece())); - - std::string headers = "HTTP/1.1 200 OK\nContent-Type: text/html\n"; - base::StrAppend(&headers, {BuildOriginTrialHeader()}); - URLLoaderInterceptor::WriteResponse(path, params->client.get(), &headers, - absl::nullopt, - /*url=*/params->url_request.url); - return true; - } -}; - -IN_PROC_BROWSER_TEST_F(SameOriginPartitionedCookiesOriginTrialBrowserTest, - ValidTokenAndHeaderPresent) { - SetTestOptions( - {/*has_ot_token=*/true, /*valid_ot_token=*/true, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {partitioned_cookies_url()}); - - 
NavigateTwiceAndCheckClientHint(partitioned_cookies_url(), true, true); -} - -IN_PROC_BROWSER_TEST_F(SameOriginPartitionedCookiesOriginTrialBrowserTest, - NoTokenPresent) { - SetTestOptions( - {/*has_ot_token=*/false, /*valid_ot_token=*/true, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {partitioned_cookies_url()}); - - NavigateTwiceAndCheckClientHint(partitioned_cookies_url(), false, false); -} - -IN_PROC_BROWSER_TEST_F(SameOriginPartitionedCookiesOriginTrialBrowserTest, - InvalidToken) { - SetTestOptions( - {/*has_ot_token=*/true, /*valid_ot_token=*/false, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {partitioned_cookies_url()}); - - NavigateTwiceAndCheckClientHint(partitioned_cookies_url(), false, false); -} - -IN_PROC_BROWSER_TEST_F(SameOriginPartitionedCookiesOriginTrialBrowserTest, - NoAcceptChHeader) { - SetTestOptions( - {/*has_ot_token=*/true, /*valid_ot_token=*/true, - /*has_accept_ch_header=*/false, /*has_critical_ch_header=*/false}, - {partitioned_cookies_url()}); - - NavigateTwiceAndCheckClientHint(partitioned_cookies_url(), false, false); -} - -IN_PROC_BROWSER_TEST_F(SameOriginPartitionedCookiesOriginTrialBrowserTest, - ConvertsPartitionedCookieOnOptOut_NoToken) { - // First check on the first request we still send the cookie and the - // Sec-CH-Partitioned-Cookies header in the false state. 
- SetTestOptions( - {/*has_ot_token=*/false, /*valid_ot_token=*/false, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {partitioned_cookies_url()}); - - SetCookie("__Host-A", "0", GURL(kOriginUrl), - net::CookiePartitionKey::FromURLForTesting(GURL(kOriginUrl))); - - auto cookies = GetCookies(GURL(kOriginUrl)); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_TRUE(cookies[0].IsPartitioned()); - - NavigateTo(partitioned_cookies_url()); - - cookies = GetCookies(GURL(kOriginUrl)); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_FALSE(cookies[0].IsPartitioned()); -} - -IN_PROC_BROWSER_TEST_F(SameOriginPartitionedCookiesOriginTrialBrowserTest, - ConvertsPartitionedCookieOnOptOut_InvalidToken) { - // First check on the first request we still send the cookie and the - // Sec-CH-Partitioned-Cookies header in the false state. - SetTestOptions( - {/*has_ot_token=*/true, /*valid_ot_token=*/false, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {partitioned_cookies_url()}); - - SetCookie("__Host-A", "0", GURL(kOriginUrl), - net::CookiePartitionKey::FromURLForTesting(GURL(kOriginUrl))); - - auto cookies = GetCookies(GURL(kOriginUrl)); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_TRUE(cookies[0].IsPartitioned()); - - NavigateTo(partitioned_cookies_url()); - - cookies = GetCookies(GURL(kOriginUrl)); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_FALSE(cookies[0].IsPartitioned()); -} - -IN_PROC_BROWSER_TEST_F(SameOriginPartitionedCookiesOriginTrialBrowserTest, - ConvertsPartitionedCookieOnOptOut_NoAcceptCh) { - // First check on the first request we still send the cookie and the - // Sec-CH-Partitioned-Cookies header in the false state. 
- SetTestOptions( - {/*has_ot_token=*/true, /*valid_ot_token=*/true, - /*has_accept_ch_header=*/false, /*has_critical_ch_header=*/false}, - {partitioned_cookies_url()}); - - SetCookie("__Host-A", "0", GURL(kOriginUrl), - net::CookiePartitionKey::FromURLForTesting(GURL(kOriginUrl))); - - auto cookies = GetCookies(GURL(kOriginUrl)); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_TRUE(cookies[0].IsPartitioned()); - - NavigateTo(partitioned_cookies_url()); - - cookies = GetCookies(GURL(kOriginUrl)); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_FALSE(cookies[0].IsPartitioned()); -} - -// Tests that verify Sec-CH-Partitioned-Cookies client hint is sent if and only -// if the PartitionedCookies Origin Trial token is present and valid in the -// response headers. -// This test is specifically to exercise that third-party embeds will get the -// client hint if the top-level origin is opted into the trial. -// -// The test Origin Trial token was generated by running: -// python tools/origin_trials/generate_token.py https://my-site.com:44444 \ -// PartitionedCookies \ -// --expire-timestamp=2000000000 -// -// The Origin Trial token expires in 2033. Generate a new token by then, or -// find a better way to re-generate a test trial token. 
-class ThirdPartyPartitionedCookiesOriginTrialBrowserTest - : public PartitionedCookiesOriginTrialBrowserTest { - public: - ThirdPartyPartitionedCookiesOriginTrialBrowserTest() - : https_server_(net::EmbeddedTestServer::TYPE_HTTPS) { - https_server_.ServeFilesFromSourceDirectory( - "chrome/test/data/client_hints"); - https_server_.RegisterRequestMonitor(base::BindRepeating( - &ThirdPartyPartitionedCookiesOriginTrialBrowserTest:: - MonitorResourceRequest, - base::Unretained(this))); - EXPECT_TRUE(https_server_.Start()); - } - - void SetUpOnMainThread() override { - // We use a URLLoaderInterceptor, rather than the EmbeddedTestServer, since - // the origin trial token in the response is associated with a fixed - // origin, whereas EmbeddedTestServer serves content on a random port. - url_loader_interceptor_ = std::make_unique< - URLLoaderInterceptor>(base::BindRepeating( - &ThirdPartyPartitionedCookiesOriginTrialBrowserTest::InterceptRequest, - base::Unretained(this))); - InProcessBrowserTest::SetUpOnMainThread(); - } - - // The URL that was used to register the Origin Trial token as the first - // party. Requests to this origin should be handled by URLLoader interceptor. - static constexpr const char kFirstPartyOriginUrl[] = - "https://my-site.com:44444"; - - // The URL of the site receiving cookies. - // Requests to this origin should be handled by the test server. 
- static constexpr char kCookieOriginUrlNoPort[] = "https://127.0.0.1:"; - - GURL partitioned_cookies_url() const { - return GURL(base::StrCat({kCookieOriginUrlNoPort, - base::NumberToString(https_server_.port()), - "/partitioned_cookies_embeddee.html"})); - } - - GURL origin_trial_participant_url() const { - return GURL(base::StrCat( - {kFirstPartyOriginUrl, "/partitioned_cookies_embedder.html"})); - } - - GURL last_requested_url() { - base::AutoLock lock(last_request_lock_); - return last_requested_url_; - } - - absl::optional<std::string> last_sec_ch_partitioned_cookies_value() { - base::AutoLock lock(last_request_lock_); - return last_sec_ch_partitioned_cookies_value_; - } - - // Called by `https_server_`. - void MonitorResourceRequest(const net::test_server::HttpRequest& request) { - base::AutoLock lock(last_request_lock_); - last_requested_url_ = request.GetURL(); - const auto& it = request.headers.find("sec-ch-partitioned-cookies"); - last_sec_ch_partitioned_cookies_value_ = - it != request.headers.end() ? 
absl::make_optional(it->second) - : absl::nullopt; - } - - // URLLoaderInterceptor callback - bool InterceptRequest(URLLoaderInterceptor::RequestParams* params) { - if (expected_request_urls_.find(params->url_request.url) == - expected_request_urls_.end()) - return false; - - if (params->url_request.url.path() == - base::StrCat({"/partitioned_cookies_embedder.html"})) { - std::string headers = "HTTP/1.1 200 OK\nContent-Type: text/html\n"; - base::StrAppend(&headers, {BuildOriginTrialHeader()}); - std::string body = "<html><head>"; - base::StrAppend(&body, {"</head><body>"}); - base::StrAppend(&body, {BuildIframeHTML()}); - base::StrAppend(&body, {"</body></html>"}); - URLLoaderInterceptor::WriteResponse(headers, body, params->client.get()); - return true; - } - - NOTREACHED(); - return false; - } - - private: - std::string BuildOriginTrialHeader() const override { - std::string headers; - - // The test Origin Trial token was generated by running: - // python tools/origin_trials/generate_token.py https://my-site.com:44444 \ - // PartitionedCookies \ - // --expire-timestamp=2000000000 - // - static constexpr const char kOriginTrialToken[] = - "A56J4whdQCcxi5r8mpiT1kXOUobK2NMpZmYtJaT5HD/" - "uDBtZgrVipOJhhp4VDL37SA4l9ve6dyZCs5Gr/" - "mEuGQcAAABeeyJvcmlnaW4iOiAiaHR0cHM6Ly9teS1zaXRlLmNvbTo0NDQ0NCIsICJmZWF" - "0dXJlIjogIlBhcnRpdGlvbmVkQ29va2llcyIsICJleHBpcnkiOiAyMDAwMDAwMDAwfQ=="; - - if (test_options_.has_accept_ch_header) { - base::StrAppend(&headers, - {"Accept-CH: ", "sec-ch-partitioned-cookies", "\n"}); - } - if (test_options_.has_critical_ch_header) { - base::StrAppend(&headers, - {"Critical-CH: ", "sec-ch-partitioned-cookies", "\n"}); - } - if (test_options_.has_ot_token) { - base::StrAppend( - &headers, - {"Origin-Trial: ", - test_options_.valid_ot_token ? 
kOriginTrialToken : "invalid", "\n"}); - } - - return headers; - } - - std::string BuildIframeHTML() { - std::string html = "<iframe src=\""; - base::StrAppend( - &html, - {https_server_.GetURL("/partitioned_cookies_embeddee.html").spec(), - "\"></iframe>"}); - return html; - } - - net::EmbeddedTestServer https_server_; - base::Lock last_request_lock_; - GURL last_requested_url_; - absl::optional<std::string> last_sec_ch_partitioned_cookies_value_; -}; - -IN_PROC_BROWSER_TEST_F(ThirdPartyPartitionedCookiesOriginTrialBrowserTest, - ValidTokenAndHeaderPresent) { - SetTestOptions( - {/*has_ot_token=*/true, /*valid_ot_token=*/true, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {origin_trial_participant_url()}); - - NavigateTo(origin_trial_participant_url()); - - EXPECT_EQ(last_requested_url(), partitioned_cookies_url()); - EXPECT_EQ(last_sec_ch_partitioned_cookies_value(), "?1"); -} - -IN_PROC_BROWSER_TEST_F(ThirdPartyPartitionedCookiesOriginTrialBrowserTest, - InvalidToken) { - SetTestOptions( - {/*has_ot_token=*/true, /*valid_ot_token=*/false, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {origin_trial_participant_url()}); - - NavigateTo(origin_trial_participant_url()); - - EXPECT_EQ(last_requested_url(), partitioned_cookies_url()); - EXPECT_EQ(last_sec_ch_partitioned_cookies_value(), absl::nullopt); -} - -IN_PROC_BROWSER_TEST_F(ThirdPartyPartitionedCookiesOriginTrialBrowserTest, - NoToken) { - SetTestOptions( - {/*has_ot_token=*/false, /*valid_ot_token=*/false, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {origin_trial_participant_url()}); - - NavigateTo(origin_trial_participant_url()); - - EXPECT_EQ(last_requested_url(), partitioned_cookies_url()); - EXPECT_EQ(last_sec_ch_partitioned_cookies_value(), absl::nullopt); -} - -IN_PROC_BROWSER_TEST_F(ThirdPartyPartitionedCookiesOriginTrialBrowserTest, - NoAcceptChHeader) { - SetTestOptions( - {/*has_ot_token=*/true, /*valid_ot_token=*/true, - 
/*has_accept_ch_header=*/false, /*has_critical_ch_header=*/false}, - {origin_trial_participant_url()}); - - NavigateTo(origin_trial_participant_url()); - - EXPECT_EQ(last_requested_url(), partitioned_cookies_url()); - EXPECT_EQ(last_sec_ch_partitioned_cookies_value(), absl::nullopt); -} - -IN_PROC_BROWSER_TEST_F(ThirdPartyPartitionedCookiesOriginTrialBrowserTest, - ConvertsPartitionedCookieOnOptOut_NoToken) { - // First check on the first request we still send the cookie and the - // Sec-CH-Partitioned-Cookies header in the false state. - SetTestOptions( - {/*has_ot_token=*/false, /*valid_ot_token=*/false, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {origin_trial_participant_url()}); - - SetCookie( - "__Host-A", "0", partitioned_cookies_url(), - net::CookiePartitionKey::FromURLForTesting(GURL(kFirstPartyOriginUrl))); - - auto cookies = GetCookies(partitioned_cookies_url()); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_TRUE(cookies[0].IsPartitioned()); - - NavigateTo(origin_trial_participant_url()); - // Can only test this header is present when using https_server_ because it is - // added by the network service. - EXPECT_EQ(last_sec_ch_partitioned_cookies_value(), "?0"); - - EXPECT_EQ(last_requested_url(), partitioned_cookies_url()); - - cookies = GetCookies(partitioned_cookies_url()); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_FALSE(cookies[0].IsPartitioned()); -} - -IN_PROC_BROWSER_TEST_F(ThirdPartyPartitionedCookiesOriginTrialBrowserTest, - ConvertsPartitionedCookieOnOptOut_InvalidToken) { - // First check on the first request we still send the cookie and the - // Sec-CH-Partitioned-Cookies header in the false state. 
- SetTestOptions( - {/*has_ot_token=*/true, /*valid_ot_token=*/false, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {origin_trial_participant_url()}); - - SetCookie( - "__Host-A", "0", partitioned_cookies_url(), - net::CookiePartitionKey::FromURLForTesting(GURL(kFirstPartyOriginUrl))); - - auto cookies = GetCookies(partitioned_cookies_url()); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_TRUE(cookies[0].IsPartitioned()); - - NavigateTo(origin_trial_participant_url()); - // Can only test this header is present when using https_server_ because it is - // added by the network service. - EXPECT_EQ(last_sec_ch_partitioned_cookies_value(), "?0"); - - EXPECT_EQ(last_requested_url(), partitioned_cookies_url()); - - cookies = GetCookies(partitioned_cookies_url()); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_FALSE(cookies[0].IsPartitioned()); -} - -IN_PROC_BROWSER_TEST_F(ThirdPartyPartitionedCookiesOriginTrialBrowserTest, - ConvertsPartitionedCookieOnOptOut_NoAcceptChHeader) { - // First check on the first request we still send the cookie and the - // Sec-CH-Partitioned-Cookies header in the false state. - SetTestOptions( - {/*has_ot_token=*/true, /*valid_ot_token=*/true, - /*has_accept_ch_header=*/false, /*has_critical_ch_header=*/false}, - {origin_trial_participant_url()}); - - SetCookie( - "__Host-A", "0", partitioned_cookies_url(), - net::CookiePartitionKey::FromURLForTesting(GURL(kFirstPartyOriginUrl))); - - auto cookies = GetCookies(partitioned_cookies_url()); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_TRUE(cookies[0].IsPartitioned()); - - NavigateTo(origin_trial_participant_url()); - // Can only test this header is present when using https_server_ because it is - // added by the network service. 
- EXPECT_EQ(last_sec_ch_partitioned_cookies_value(), "?0"); - - EXPECT_EQ(last_requested_url(), partitioned_cookies_url()); - - cookies = GetCookies(partitioned_cookies_url()); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_FALSE(cookies[0].IsPartitioned()); -} - -class EmbeddedPartitionedCookiesOriginTrialBrowserTest - : public PartitionedCookiesOriginTrialBrowserTest { - public: - EmbeddedPartitionedCookiesOriginTrialBrowserTest() = default; - - void SetUpOnMainThread() override { - // We use a URLLoaderInterceptor, rather than the EmbeddedTestServer, since - // the origin trial token in the response is associated with a fixed - // origin, whereas EmbeddedTestServer serves content on a random port. - url_loader_interceptor_ = - std::make_unique<URLLoaderInterceptor>(base::BindRepeating( - &EmbeddedPartitionedCookiesOriginTrialBrowserTest::InterceptRequest, - base::Unretained(this))); - InProcessBrowserTest::SetUpOnMainThread(); - } - - // URLLoaderInterceptor callback - bool InterceptRequest(URLLoaderInterceptor::RequestParams* params) { - if (expected_request_urls_.find(params->url_request.url) == - expected_request_urls_.end()) - return false; - - if (params->url_request.url.path() == - base::StrCat({"/partitioned_cookies_embedder.html"})) { - std::string headers = "HTTP/1.1 200 OK\nContent-Type: text/html\n"; - std::string body = "<html><head>"; - base::StrAppend(&body, {"</head><body>"}); - base::StrAppend(&body, {BuildIframeHTML()}); - base::StrAppend(&body, {"</body></html>"}); - URLLoaderInterceptor::WriteResponse(headers, body, params->client.get()); - return true; - } - - if (params->url_request.url.path() == - base::StrCat({"/partitioned_cookies_embeddee.html"})) { - std::string headers = "HTTP/1.1 200 OK\nContent-Type: text/html\n"; - base::StrAppend(&headers, {BuildOriginTrialHeader()}); - URLLoaderInterceptor::WriteResponse( - "chrome/test/data/client_hints/partitioned_cookies_embeddee.html", - 
params->client.get(), &headers, absl::nullopt, - params->url_request.url); - return true; - } - - NOTREACHED(); - return false; - } - - GURL embedder_url() const { - return GURL(base::StrCat( - {kFirstPartyOriginUrl, "/partitioned_cookies_embedder.html"})); - } - - // In this test, the OT participant is the embedded site. - GURL origin_trial_participant_url() const { - return GURL( - base::StrCat({kCookieOriginUrl, "/partitioned_cookies_embeddee.html"})); - } - - // The URL that was used to register the Origin Trial token as the first - // party. Requests to this origin should be handled by URLLoader interceptor. - static constexpr const char kFirstPartyOriginUrl[] = - "https://my-site.com:44444"; - - // The URL of the site receiving cookies. - // Requests to this origin should be handled by the test server. - static constexpr char kCookieOriginUrl[] = "https://127.0.0.1:44444"; - - std::string BuildOriginTrialHeader() const override { - std::string headers; - - static constexpr const char kOriginTrialToken[] = - "A1mBOyrOKGAaaoT8mjM1qSNrOSrdDUa9WyqicVLlDGW3feIBSdWqSiHDAXUeKkGKaVqUiC" - "X8avwCM0gpG5LtxgAAAAByeyJvcmlnaW4iOiAiaHR0cHM6Ly8xMjcuMC4wLjE6NDQ0NDQi" - "LCAiZmVhdHVyZSI6ICJQYXJ0aXRpb25lZENvb2tpZXMiLCAiZXhwaXJ5IjogMjAwMDAwMD" - "AwMCwgImlzVGhpcmRQYXJ0eSI6IHRydWV9"; - - if (test_options_.has_accept_ch_header) { - base::StrAppend(&headers, - {"Accept-CH: ", "sec-ch-partitioned-cookies", "\n"}); - } - if (test_options_.has_critical_ch_header) { - base::StrAppend(&headers, - {"Critical-CH: ", "sec-ch-partitioned-cookies", "\n"}); - } - if (test_options_.has_ot_token) { - base::StrAppend( - &headers, - {"Origin-Trial: ", - test_options_.valid_ot_token ? 
kOriginTrialToken : "invalid", "\n"}); - } - - return headers; - } - - private: - std::string BuildIframeHTML() { - std::string html = "<iframe src=\""; - base::StrAppend(&html, - {kCookieOriginUrl, "/partitioned_cookies_embeddee.html", - "\"></iframe>"}); - return html; - } -}; - -IN_PROC_BROWSER_TEST_F(EmbeddedPartitionedCookiesOriginTrialBrowserTest, - ValidTokenAndHeaderPresent) { - SetTestOptions( - {/*has_ot_token=*/true, /*valid_ot_token=*/true, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {embedder_url(), origin_trial_participant_url()}); - - NavigateTwiceAndCheckClientHint(embedder_url(), true, true); -} - -IN_PROC_BROWSER_TEST_F(EmbeddedPartitionedCookiesOriginTrialBrowserTest, - InvalidToken) { - SetTestOptions( - {/*has_ot_token=*/false, /*valid_ot_token=*/true, - /*has_accept_ch_header=*/true, /*has_critical_ch_header=*/false}, - {embedder_url(), origin_trial_participant_url()}); - - NavigateTwiceAndCheckClientHint(embedder_url(), false, false); -} - -class PartitionedCookiesBypassOriginTrialBrowserTest - : public PartitionedCookiesOriginTrialBrowserTest { - public: - PartitionedCookiesBypassOriginTrialBrowserTest() - : https_server_(net::EmbeddedTestServer::TYPE_HTTPS) { - https_server_.ServeFilesFromSourceDirectory( - "chrome/test/data/client_hints"); - https_server_.RegisterRequestMonitor(base::BindRepeating( - &PartitionedCookiesBypassOriginTrialBrowserTest::MonitorResourceRequest, - base::Unretained(this))); - EXPECT_TRUE(https_server_.Start()); - } - - // The URL of the site receiving cookies. - // Requests to this origin should be handled by the test server. 
- static constexpr char kCookieOriginUrlNoPort[] = "https://127.0.0.1:"; - - GURL partitioned_cookies_url() const { - return GURL(base::StrCat({kCookieOriginUrlNoPort, - base::NumberToString(https_server_.port()), - "/partitioned_cookies_embeddee.html"})); - } - - absl::optional<std::string> last_sec_ch_partitioned_cookies_value() { - base::AutoLock lock(last_request_lock_); - return last_sec_ch_partitioned_cookies_value_; - } - - void MonitorResourceRequest(const net::test_server::HttpRequest& request) { - base::AutoLock lock(last_request_lock_); - const auto& it = request.headers.find("sec-ch-partitioned-cookies"); - last_sec_ch_partitioned_cookies_value_ = - it != request.headers.end() ? absl::make_optional(it->second) - : absl::nullopt; - } - - protected: - std::unique_ptr<base::FeatureList> EnabledFeatures() override { - std::unique_ptr<base::FeatureList> feature_list(new base::FeatureList); - feature_list->InitializeFromCommandLine( - "UserAgentClientHint,CriticalClientHint,AcceptCHFrame," - "PartitionedCookies,PartitionedCookiesBypassOriginTrial", - ""); - return feature_list; - } - - net::EmbeddedTestServer https_server_; - absl::optional<std::string> last_sec_ch_partitioned_cookies_value_; - base::Lock last_request_lock_; -}; - -IN_PROC_BROWSER_TEST_F(PartitionedCookiesBypassOriginTrialBrowserTest, - ShouldAllowCookiesWithoutToken) { - SetCookie( - "__Host-A", "0", partitioned_cookies_url(), - net::CookiePartitionKey::FromURLForTesting(partitioned_cookies_url())); - - auto cookies = GetCookies(partitioned_cookies_url()); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_TRUE(cookies[0].IsPartitioned()); - - NavigateTo(partitioned_cookies_url()); - - // Check that the partitioned cookie did not get converted to unpartitioned - // even though the site never opted into the origin trial. 
- cookies = GetCookies(partitioned_cookies_url()); - EXPECT_EQ(cookies.size(), 1u); - EXPECT_EQ(cookies[0].Name(), "__Host-A"); - EXPECT_TRUE(cookies[0].IsPartitioned()); - - // We will still send the client hint in the false state if there are - // partitioned cookies on the machine. - EXPECT_EQ(last_sec_ch_partitioned_cookies_value(), "?0"); -} - // CrOS multi-profiles implementation is too different for these tests. #if !BUILDFLAG(IS_CHROMEOS_ASH)
diff --git a/chrome/browser/dev_ui_browser_resources.grd b/chrome/browser/dev_ui_browser_resources.grd index b03040e..73341144 100644 --- a/chrome/browser/dev_ui_browser_resources.grd +++ b/chrome/browser/dev_ui_browser_resources.grd
@@ -41,11 +41,11 @@ <if expr="is_android or is_linux or chromeos_ash or chromeos_lacros"> <include name="IDR_SANDBOX_INTERNALS_HTML" file="resources\sandbox_internals\sandbox_internals.html" preprocess="true" type="BINDATA" /> - <include name="IDR_SANDBOX_INTERNALS_JS" file="resources\sandbox_internals\sandbox_internals.js" preprocess="true" type="BINDATA" /> + <include name="IDR_SANDBOX_INTERNALS_JS" file="${root_gen_dir}\chrome\browser\resources\sandbox_internals\tsc\sandbox_internals.js" use_base_dir="false" type="BINDATA" /> </if> <if expr="is_win"> <include name="IDR_SANDBOX_INTERNALS_HTML" file="resources\sandbox_internals\sandbox_internals.html" preprocess="true" type="BINDATA" /> - <include name="IDR_SANDBOX_INTERNALS_JS" file="resources\sandbox_internals\sandbox_internals_win.js" type="BINDATA" /> + <include name="IDR_SANDBOX_INTERNALS_JS" file="${root_gen_dir}\chrome\browser\resources\sandbox_internals\tsc\sandbox_internals_win.js" use_base_dir="false" type="BINDATA" /> </if> <include name="IDR_SITE_ENGAGEMENT_HTML" file="resources\engagement\site_engagement.html" type="BINDATA" />
diff --git a/chrome/browser/extensions/api/certificate_provider/certificate_provider_apitest.cc b/chrome/browser/extensions/api/certificate_provider/certificate_provider_apitest.cc index 74eb23b2..2596d7d 100644 --- a/chrome/browser/extensions/api/certificate_provider/certificate_provider_apitest.cc +++ b/chrome/browser/extensions/api/certificate_provider/certificate_provider_apitest.cc
@@ -572,9 +572,8 @@ ASSERT_TRUE(listener.WaitUntilSatisfied()); // Check that we have an error message displayed. - EXPECT_EQ( - gfx::kGoogleRed600, - GetActivePinDialogView()->error_label_for_testing()->GetEnabledColor()); + EXPECT_TRUE( + GetActivePinDialogView()->IsTextStyleOfErrorLabelCorrectForTesting()); } bool SendCommand(const std::string& command) {
diff --git a/chrome/browser/feature_guide/notifications/internal/feature_notification_guide_service_impl.cc b/chrome/browser/feature_guide/notifications/internal/feature_notification_guide_service_impl.cc index c1aa1ee..c430e911 100644 --- a/chrome/browser/feature_guide/notifications/internal/feature_notification_guide_service_impl.cc +++ b/chrome/browser/feature_guide/notifications/internal/feature_notification_guide_service_impl.cc
@@ -114,7 +114,7 @@ is_low_engaged_user_ = result.is_ready && result.segment.has_value() && result.segment.value() == - optimization_guide::proto::OptimizationTarget:: + segmentation_platform::proto::SegmentId:: OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT; }
diff --git a/chrome/browser/feature_guide/notifications/internal/feature_notification_guide_service_impl_unittest.cc b/chrome/browser/feature_guide/notifications/internal/feature_notification_guide_service_impl_unittest.cc index da9e5d0..1eae77a7 100644 --- a/chrome/browser/feature_guide/notifications/internal/feature_notification_guide_service_impl_unittest.cc +++ b/chrome/browser/feature_guide/notifications/internal/feature_notification_guide_service_impl_unittest.cc
@@ -97,7 +97,7 @@ SegmentSelectionCallback callback) override { segmentation_platform::SegmentSelectionResult result; result.is_ready = true; - result.segment = optimization_guide::proto::OptimizationTarget:: + result.segment = segmentation_platform::proto::SegmentId:: OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT; std::move(callback).Run(result); } @@ -105,7 +105,7 @@ const std::string& segmentation_key) override { segmentation_platform::SegmentSelectionResult result; result.is_ready = true; - result.segment = optimization_guide::proto::OptimizationTarget:: + result.segment = segmentation_platform::proto::SegmentId:: OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT; return result; }
diff --git a/chrome/browser/flag-metadata.json b/chrome/browser/flag-metadata.json index 0ae2625..90fa4604 100644 --- a/chrome/browser/flag-metadata.json +++ b/chrome/browser/flag-metadata.json
@@ -64,6 +64,11 @@ "expiry_milestone": 115 }, { + "name": "adaptive-charging-for-testing", + "owners": [ "thanhdng" ], + "expiry_milestone": 115 + }, + { "name": "add-passwords-in-settings", "owners": [ "vidhanj", "mamir", "lizapopova" ], "expiry_milestone": 103 @@ -146,11 +151,6 @@ "expiry_milestone": 103 }, { - "name": "android-detailed-language-settings", - "owners": [ "perrier", "chrome-language@google.com" ], - "expiry_milestone": 101 - }, - { "name": "android-force-app-language-prompt", "owners": [ "perrier", "chrome-language@google.com" ], "expiry_milestone": 104 @@ -330,7 +330,7 @@ "jds@google.com", "chrome-voice@google.com" ], - "expiry_milestone": 103 + "expiry_milestone": 110 }, { "name": "assistant-consent-simplified-text", @@ -338,7 +338,7 @@ "jds@google.com", "chrome-voice@google.com" ], - "expiry_milestone": 103 + "expiry_milestone": 110 }, { "name": "assistant-consent-v2", @@ -346,7 +346,7 @@ "basiaz@google.com", "chrome-voice@google.com" ], - "expiry_milestone": 103 + "expiry_milestone": 110 }, { "name": "assistant-intent-page-url", @@ -355,7 +355,7 @@ "basiaz@google.com", "chrome-voice@google.com" ], - "expiry_milestone": 103 + "expiry_milestone": 110 }, { "name": "assistant-intent-translate-info", @@ -364,7 +364,7 @@ "basiaz@google.com", "chrome-voice@google.com" ], - "expiry_milestone": 103 + "expiry_milestone": 110 }, { "name": "audio-settings-page", @@ -3869,6 +3869,11 @@ "expiry_milestone": 110 }, { + "name": "launcher-hide-continue-section", + "owners": ["jamescook", "//ash/app_list/OWNERS"], + "expiry_milestone": 109 + }, + { "name": "launcher-lacros-integration", "owners": ["wrong", "thanhdng", "tby"], "expiry_milestone": 115 @@ -4420,7 +4425,7 @@ "basiaz@google.com", "chrome-voice@google.com" ], - "expiry_milestone": 103 + "expiry_milestone": 110 }, { "name": "omnibox-blur-with-escape", @@ -5828,7 +5833,7 @@ "basiaz@google.com", "chrome-voice@google.com" ], - "expiry_milestone": 103 + "expiry_milestone": 110 }, { "name": 
"translate-force-trigger-on-english", @@ -5844,7 +5849,7 @@ "basiaz@google.com", "chrome-voice@google.com" ], - "expiry_milestone": 103 + "expiry_milestone": 110 }, { "name": "translate-message-ui",
diff --git a/chrome/browser/flag_descriptions.cc b/chrome/browser/flag_descriptions.cc index cb2a9517..3537b55 100644 --- a/chrome/browser/flag_descriptions.cc +++ b/chrome/browser/flag_descriptions.cc
@@ -2988,11 +2988,6 @@ " This feature is only available for android P+ devices. Disabling it also " " disables SurfaceControl."; -const char kAndroidDetailedLanguageSettingsName[] = - "Detailed Language Settings"; -const char kAndroidDetailedLanguageSettingsDescription[] = - "Enable the new detailed language settings page"; - const char kAndroidForceAppLanguagePromptName[] = "Force second run app language prompt"; const char kAndroidForceAppLanguagePromptDescription[] = @@ -4096,6 +4091,15 @@ "Enable hardware-accelerated mjpeg decode for captured frame where " "available."; +const char kAdaptiveChargingForTestingName[] = + "Show adaptive charging notifications for testing"; +const char kAdaptiveChargingForTestingDescription[] = + "Show adaptive charging notifications and nudges for testing. This is " + "meant to be used by developers to test the feature UI only. The " + "notifications will be shown after the device is plugged in to the " + "charger. Please DO NOT enable this if you're not a developer who wants to " + "test the UI of the adaptive charging feature."; + const char kAdaptiveChargingName[] = "Enable adaptive charging feature"; const char kAdaptiveChargingDescription[] = "Show settings to enable/disable adaptive charging feature."; @@ -5540,6 +5544,12 @@ "When enabled, if a user removes a continue section suggestion, a dialog " "will appear on the launcher requesting feedback on the suggestions shown."; +const char kLauncherHideContinueSectionName[] = + "Launcher hide continue section"; +const char kLauncherHideContinueSectionDescription[] = + "Adds a 'Hide all suggestions' option to the continue section item " + "right-click menus."; + const char kLauncherNudgeName[] = "Enable launcher nudge"; const char kLauncherNudgeDescription[] = "Enables nudges that bring new users' attention to the launcher button.";
diff --git a/chrome/browser/flag_descriptions.h b/chrome/browser/flag_descriptions.h index 57b445a..90ebbfb 100644 --- a/chrome/browser/flag_descriptions.h +++ b/chrome/browser/flag_descriptions.h
@@ -1684,9 +1684,6 @@ extern const char kAImageReaderName[]; extern const char kAImageReaderDescription[]; -extern const char kAndroidDetailedLanguageSettingsName[]; -extern const char kAndroidDetailedLanguageSettingsDescription[]; - extern const char kAndroidForceAppLanguagePromptName[]; extern const char kAndroidForceAppLanguagePromptDescription[]; @@ -2346,6 +2343,9 @@ extern const char kAdaptiveChargingName[]; extern const char kAdaptiveChargingDescription[]; +extern const char kAdaptiveChargingForTestingName[]; +extern const char kAdaptiveChargingForTestingDescription[]; + extern const char kAllowDisableTouchpadHapticFeedbackName[]; extern const char kAllowDisableTouchpadHapticFeedbackDescription[]; @@ -3173,6 +3173,9 @@ extern const char kLauncherFeedbackOnContinueSectionRemoveName[]; extern const char kLauncherFeedbackOnContinueSectionRemoveDescription[]; +extern const char kLauncherHideContinueSectionName[]; +extern const char kLauncherHideContinueSectionDescription[]; + extern const char kLauncherNudgeName[]; extern const char kLauncherNudgeDescription[];
diff --git a/chrome/browser/headless/headless_mode_browsertest_win.cc b/chrome/browser/headless/headless_mode_browsertest_win.cc index 9257f6f..bcb49a4 100644 --- a/chrome/browser/headless/headless_mode_browsertest_win.cc +++ b/chrome/browser/headless/headless_mode_browsertest_win.cc
@@ -58,3 +58,53 @@ EXPECT_TRUE(browser()->window()->IsVisible()); EXPECT_FALSE(::IsWindowVisible(desktop_window_hwnd)); } + +IN_PROC_BROWSER_TEST_F(HeadlessModeBrowserTest, + MinimizedRestoredWindowVisibility) { + DesktopWindowTreeHostWinWrapper* desktop_window_tree_host = + static_cast<DesktopWindowTreeHostWinWrapper*>( + browser()->window()->GetNativeWindow()->GetHost()); + HWND desktop_window_hwnd = desktop_window_tree_host->GetHWND(); + + // Verify initial state. + ASSERT_FALSE(browser()->window()->IsMinimized()); + EXPECT_TRUE(browser()->window()->IsVisible()); + EXPECT_FALSE(::IsWindowVisible(desktop_window_hwnd)); + + // Verify minimized state. + browser()->window()->Minimize(); + ASSERT_TRUE(browser()->window()->IsMinimized()); + EXPECT_TRUE(browser()->window()->IsVisible()); + EXPECT_FALSE(::IsWindowVisible(desktop_window_hwnd)); + + // Verify restored state. + browser()->window()->Restore(); + ASSERT_FALSE(browser()->window()->IsMinimized()); + EXPECT_TRUE(browser()->window()->IsVisible()); + EXPECT_FALSE(::IsWindowVisible(desktop_window_hwnd)); +} + +IN_PROC_BROWSER_TEST_F(HeadlessModeBrowserTest, + MaximizedRestoredWindowVisibility) { + DesktopWindowTreeHostWinWrapper* desktop_window_tree_host = + static_cast<DesktopWindowTreeHostWinWrapper*>( + browser()->window()->GetNativeWindow()->GetHost()); + HWND desktop_window_hwnd = desktop_window_tree_host->GetHWND(); + + // Verify initial state. + ASSERT_FALSE(browser()->window()->IsMaximized()); + EXPECT_TRUE(browser()->window()->IsVisible()); + EXPECT_FALSE(::IsWindowVisible(desktop_window_hwnd)); + + // Verify maximized state. + browser()->window()->Maximize(); + ASSERT_TRUE(browser()->window()->IsMaximized()); + EXPECT_TRUE(browser()->window()->IsVisible()); + EXPECT_FALSE(::IsWindowVisible(desktop_window_hwnd)); + + // Verify restored state. 
+ browser()->window()->Restore(); + ASSERT_FALSE(browser()->window()->IsMaximized()); + EXPECT_TRUE(browser()->window()->IsVisible()); + EXPECT_FALSE(::IsWindowVisible(desktop_window_hwnd)); +}
diff --git a/chrome/browser/media/router/discovery/access_code/access_code_cast_feature.cc b/chrome/browser/media/router/discovery/access_code/access_code_cast_feature.cc index 6de1dfa..5354a36a 100644 --- a/chrome/browser/media/router/discovery/access_code/access_code_cast_feature.cc +++ b/chrome/browser/media/router/discovery/access_code/access_code_cast_feature.cc
@@ -21,7 +21,7 @@ namespace features { // Enables remembering of access code cast devices. const base::Feature kAccessCodeCastRememberDevices{ - "AccessCodeCastRememberDevices", base::FEATURE_DISABLED_BY_DEFAULT}; + "AccessCodeCastRememberDevices", base::FEATURE_ENABLED_BY_DEFAULT}; } // namespace features namespace media_router {
diff --git a/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service.cc b/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service.cc index 942a8f7..5f9436c 100644 --- a/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service.cc +++ b/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service.cc
@@ -36,13 +36,13 @@ namespace { // Connect timeout value when opening a Cast socket. -const int kConnectTimeoutInSeconds = 1; +const int kConnectTimeoutInSeconds = 2; // Amount of idle time to wait before pinging the Cast device. -const int kPingIntervalInSeconds = 1; +const int kPingIntervalInSeconds = 4; // Amount of idle time to wait before disconnecting. -const int kLivenessTimeoutInSeconds = 2; +const int kLivenessTimeoutInSeconds = 8; using SinkSource = CastDeviceCountMetrics::SinkSource; using ChannelOpenedCallback = base::OnceCallback<void(bool)>; @@ -243,13 +243,16 @@ // another (preseentation). There was a pause before this method was called, // so check again to see if there's an active route for this sink. Only expire // the sink if a new route wasn't established during the pause. - auto route_id = HasActiveRoute(sink->id()); + auto route_id = GetActiveRouteId(sink->id()); // Only remove the sink if there is still no active routes for this sink. if (base::FeatureList::IsEnabled(features::kAccessCodeCastRememberDevices)) { + // If a sink is pending expiration that means we can + // remove it from the media router. if (!route_id.has_value() && pending_expirations_.count(sink->id())) { RemoveSinkIdFromAllEntries(sink->id()); RemoveMediaSinkFromRouter(sink); + pending_expirations_.erase(sink->id()); } } else { if (!route_id.has_value()) { @@ -319,17 +322,27 @@ "The sink already exists in the media router, no channel " "needs to be opened.", sink.id(), "", ""); + + // The logic below only pertains to the addition of access code devices that + // were added via access code (not via stored devices). + if (sink.cast_data().discovery_type != + CastDiscoveryType::kAccessCodeManualEntry) { + // We must call the |add_sink_callback| in all conditional branches. + std::move(add_sink_callback).Run(AddSinkResultCode::OK, sink.id()); + return; + } // Check to see if this sink has an active route. 
If so, we need to // terminate the route before alerting the dialog to discovery success. // This is because any attempt to start a route on a sink that already has // one won't be successful. - auto route_id = HasActiveRoute(sink.id()); + auto route_id = GetActiveRouteId(sink.id()); if (route_id.has_value()) { - media_router_->GetLogger()->LogInfo( - mojom::LogCategory::kDiscovery, kLoggerComponent, - "There was an existing route when discovery occurred, attempting to " - "terminate it.", - sink.id(), "", ""); + media_router_->GetLogger()->LogInfo(mojom::LogCategory::kDiscovery, + kLoggerComponent, + "There was an existing route when " + "discovery occurred, attempting to " + "terminate it.", + sink.id(), "", ""); media_router_->TerminateRoute(route_id.value()); pending_callbacks_.emplace(sink.id(), std::move(add_sink_callback)); } else { @@ -348,21 +361,51 @@ auto returned_channel_cb = base::BindPostTask(task_runner_, std::move(channel_cb)); - auto backoff_entry = std::make_unique<net::BackoffEntry>(&backoff_policy_); media_router_->GetLogger()->LogInfo( mojom::LogCategory::kDiscovery, kLoggerComponent, "Attempting to open a cast channel.", sink.id(), "", ""); + + switch (sink.cast_data().discovery_type) { + // For the manual entry case we use our own specific back off and open + // params so that failure happens much faster. + case CastDiscoveryType::kAccessCodeManualEntry: { + auto backoff_entry = + std::make_unique<net::BackoffEntry>(&backoff_policy_); + OpenChannelWithParams(std::move(backoff_entry), sink, + std::move(returned_channel_cb), + CreateCastSocketOpenParams(sink)); + break; + } + // For all other cases (such as remembered devices), just use the default + // parameters that the CastMediaSinkServiceImpl already uses. 
+ default: { + base::PostTaskAndReplyWithResult( + cast_media_sink_service_impl_->task_runner().get(), FROM_HERE, + base::BindOnce(&CastMediaSinkServiceImpl::CreateCastSocketOpenParams, + base::Unretained(cast_media_sink_service_impl_), sink), + base::BindOnce(&AccessCodeCastSinkService::OpenChannelWithParams, + weak_ptr_factory_.GetWeakPtr(), nullptr, sink, + std::move(returned_channel_cb))); + } + } +} + +void AccessCodeCastSinkService::OpenChannelWithParams( + std::unique_ptr<net::BackoffEntry> backoff_entry, + const MediaSinkInternal& sink, + base::OnceCallback<void(bool)> channel_opened_cb, + cast_channel::CastSocketOpenParams open_params) { cast_media_sink_service_impl_->task_runner()->PostTask( FROM_HERE, base::BindOnce(&CastMediaSinkServiceImpl::OpenChannel, base::Unretained(cast_media_sink_service_impl_), sink, std::move(backoff_entry), SinkSource::kAccessCode, - std::move(returned_channel_cb), + std::move(channel_opened_cb), CreateCastSocketOpenParams(sink))); } -absl::optional<const MediaRoute::Id> AccessCodeCastSinkService::HasActiveRoute( - const MediaSink::Id& sink_id) { +absl::optional<const MediaRoute::Id> +AccessCodeCastSinkService::GetActiveRouteId(const MediaSink::Id& sink_id) { auto routes = media_router_->GetCurrentRoutes(); auto route_it = std::find_if(routes.begin(), routes.end(), [&sink_id](const MediaRoute& route) { @@ -589,12 +632,10 @@ void AccessCodeCastSinkService::AddStoredDevicesToMediaRouter( const std::vector<MediaSinkInternal> cast_sinks) { - // Let the media router handle addition. 
- cast_media_sink_service_impl_->task_runner()->PostTask( - FROM_HERE, - base::BindOnce(&CastMediaSinkServiceImpl::OpenChannelsWithRandomizedDelay, - base::Unretained(cast_media_sink_service_impl_), - cast_sinks, SinkSource::kAccessCode)); + std::vector<MediaSinkInternal> cast_sinks_to_add; + for (auto cast_sink : cast_sinks) { + AddSinkToMediaRouter(cast_sink, base::DoNothing()); + } } void AccessCodeCastSinkService::OnExpiration(const MediaSinkInternal& sink) { @@ -606,7 +647,7 @@ "references.", sink.id(), "", ""); - auto route_id = HasActiveRoute(sink.id()); + auto route_id = GetActiveRouteId(sink.id()); // The given sink still has an active route, don't remove it yet and wait for // the route to end before we expire it. if (route_id.has_value()) { @@ -636,6 +677,11 @@ if (!sink) { return; } + DCHECK(!GetActiveRouteId(sink->id()).has_value()) + << "This sink " + sink->id() + + " still has an active route, we should not be removing it!"; + if (GetActiveRouteId(sink->id()).has_value()) + return; media_router_->GetLogger()->LogInfo( mojom::LogCategory::kDiscovery, kLoggerComponent, "Attempting to disconnect and remove the cast sink from " @@ -696,13 +742,22 @@ void AccessCodeCastSinkService::RemoveExistingSinksOnNetwork() { for (auto& sink_id_keypair : current_session_expiration_timers_) { - // Must find the sink from media router for removal since it has more total - // information. + auto sink_id = sink_id_keypair.first; + // If there is an active route for this sink -- don't attempt to remove it. + // In this case we let the Media Router handle removals from the media + // router when a network is changed with an active route. + if (GetActiveRouteId(sink_id).has_value()) { + continue; + } + + // There are no active routes for this sink so it is safe to remove from the + // media router. Must find the sink from media router for removal since it + // has more total information. 
base::PostTaskAndReplyWithResult( cast_media_sink_service_impl_->task_runner().get(), FROM_HERE, base::BindOnce(&CastMediaSinkServiceImpl::GetSinkById, base::Unretained(cast_media_sink_service_impl_), - sink_id_keypair.first), + sink_id), base::BindOnce(&AccessCodeCastSinkService::RemoveMediaSinkFromRouter, weak_ptr_factory_.GetWeakPtr())); } @@ -713,7 +768,6 @@ if (base::FeatureList::IsEnabled(features::kAccessCodeCastRememberDevices)) { RemoveExistingSinksOnNetwork(); ResetExpirationTimers(); - pending_expirations_.clear(); InitAllStoredDevices(); } }
diff --git a/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service.h b/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service.h index 742bb4a..8a3135d 100644 --- a/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service.h +++ b/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service.h
@@ -97,6 +97,10 @@ AccessCodeCastDeviceRemovedAfterRouteEndsExpirationDisabled); FRIEND_TEST_ALL_PREFIXES(AccessCodeCastSinkServiceTest, AddExistingSinkToMediaRouterWithRoute); + FRIEND_TEST_ALL_PREFIXES(AccessCodeCastSinkServiceTest, + TestChangeNetworkWithRouteActive); + FRIEND_TEST_ALL_PREFIXES(AccessCodeCastSinkServiceTest, + TestChangeNetworkWithRouteActiveExpiration); // media_router::MediaRoutesObserver: void OnRoutesUpdated(const std::vector<MediaRoute>& routes) override; @@ -160,6 +164,10 @@ TestChangeEnabledPref); FRIEND_TEST_ALL_PREFIXES(AccessCodeCastSinkServiceTest, TestChangeDurationPref); + FRIEND_TEST_ALL_PREFIXES(AccessCodeCastSinkServiceTest, + TestChangeNetworkWithRouteActive); + FRIEND_TEST_ALL_PREFIXES(AccessCodeCastSinkServiceTest, + TestChangeNetworkWithRouteActiveExpiration); // Constructor used for testing. AccessCodeCastSinkService( @@ -186,10 +194,14 @@ void OpenChannelIfNecessary(const MediaSinkInternal& sink, AddSinkResultCallback add_sink_callback, bool has_sink); + void OpenChannelWithParams(std::unique_ptr<net::BackoffEntry> backoff_entry, + const MediaSinkInternal& sink, + base::OnceCallback<void(bool)> channel_opened_cb, + cast_channel::CastSocketOpenParams open_params); // Returns a MediaRoute if the given |sink_id| corresponds to a route // currently active in the media router. - absl::optional<const MediaRoute::Id> HasActiveRoute( + absl::optional<const MediaRoute::Id> GetActiveRouteId( const MediaSink::Id& sink_id); void InitAllStoredDevices(); @@ -198,6 +210,9 @@ base::TimeDelta CalculateDurationTillExpiration(const MediaSink::Id& sink_id); void OnExpiration(const MediaSinkInternal& sink); + + // It is the responsibility of the caller to ensure that no active routes + // remain before this function is called. 
void RemoveMediaSinkFromRouter(const MediaSinkInternal* sink); const base::Value::List FetchStoredDevices(); @@ -271,7 +286,9 @@ std::map<MediaSink::Id, std::unique_ptr<base::OneShotTimer>> current_session_expiration_timers_; - // Set of devices that have expired but still have an open route. + // Set of devices that have expired but still have an open route. These + // devices are removed from the media router AND removed from the pref + // service. std::set<MediaSink::Id> pending_expirations_; scoped_refptr<base::SequencedTaskRunner> task_runner_;
diff --git a/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service_unittest.cc b/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service_unittest.cc index d905aec..8c01a798 100644 --- a/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service_unittest.cc +++ b/chrome/browser/media/router/discovery/access_code/access_code_cast_sink_service_unittest.cc
@@ -152,6 +152,22 @@ discovery_network_monitor_->OnConnectionChanged(connection_type); } + void ExpectOpenChannels(std::vector<MediaSinkInternal> cast_sinks, + int num_times) { + for (auto sink : cast_sinks) { + EXPECT_CALL(*mock_cast_media_sink_service_impl(), + OpenChannel(sink, _, SinkSource::kAccessCode, _, _)) + .Times(num_times); + } + } + + void ExpectHasSink(std::vector<MediaSinkInternal> cast_sinks, int num_times) { + for (auto sink : cast_sinks) { + EXPECT_CALL(*mock_cast_media_sink_service_impl(), HasSink(sink.id())) + .Times(num_times); + } + } + protected: content::BrowserTaskEnvironment task_environment_{ base::test::TaskEnvironment::TimeSource::MOCK_TIME}; @@ -286,7 +302,7 @@ TEST_F(AccessCodeCastSinkServiceTest, AccessCodeCastDeviceRemovedAfterRouteEndsExpirationDisabled) { feature_list_.Reset(); - feature_list_.Init(); + feature_list_.InitAndDisableFeature(features::kAccessCodeCastRememberDevices); // Test to see that an AccessCode cast sink will be removed after the session // is ended. mock_time_task_runner()->FastForwardUntilNoTasksRemain(); @@ -363,6 +379,9 @@ // exists in the media router. MockAddSinkResultCallback mock_callback; MediaSinkInternal cast_sink1 = CreateCastSink(1); + auto cast_data = cast_sink1.cast_data(); + cast_data.discovery_type = CastDiscoveryType::kAccessCodeManualEntry; + cast_sink1.set_cast_data(cast_data); EXPECT_CALL(*mock_cast_media_sink_service_impl(), OpenChannel(cast_sink1, _, SinkSource::kAccessCode, _, _)) @@ -416,6 +435,9 @@ // exist. 
MockAddSinkResultCallback mock_callback; MediaSinkInternal cast_sink1 = CreateCastSink(1); + auto cast_data = cast_sink1.cast_data(); + cast_data.discovery_type = CastDiscoveryType::kAccessCodeManualEntry; + cast_sink1.set_cast_data(cast_data); EXPECT_CALL(*mock_cast_media_sink_service_impl(), OpenChannel(cast_sink1, _, SinkSource::kAccessCode, _, _)); @@ -453,15 +475,12 @@ access_code_cast_sink_service_->OnAccessCodeValidated( mock_callback.Get(), discovery_device_proto, AddSinkResultCode::OK); - // Assume sink is not present in the Media Router so a call to OpenChannel is - // made. - access_code_cast_sink_service_->OpenChannelIfNecessary( - cast_sink1, mock_callback.Get(), false); - // Channel successfully opens. access_code_cast_sink_service_->OnChannelOpenedResult(mock_callback.Get(), "123456", true); mock_time_task_runner()->FastForwardUntilNoTasksRemain(); + FastForwardUiAndIoTasks(); + mock_time_task_runner()->FastForwardUntilNoTasksRemain(); } TEST_F(AccessCodeCastSinkServiceTest, InvalidDiscoveryDevice) { @@ -550,9 +569,8 @@ access_code_cast_sink_service_->ValidateDeviceFromSinkId(cast_sink3.id()) .value()); - EXPECT_CALL(*mock_cast_media_sink_service_impl(), - OpenChannels(cast_sinks_ethernet, SinkSource::kAccessCode)) - .Times(1); + ExpectOpenChannels(cast_sinks_ethernet, 1); + ExpectHasSink(cast_sinks_ethernet, 1); FastForwardUiAndIoTasks(); @@ -568,6 +586,8 @@ FastForwardUiAndIoTasks(); content::RunAllTasksUntilIdle(); mock_time_task_runner()->FastForwardUntilNoTasksRemain(); + FastForwardUiAndIoTasks(); + mock_time_task_runner()->FastForwardUntilNoTasksRemain(); } TEST_F(AccessCodeCastSinkServiceTest, TestChangeNetworksExpiration) { @@ -606,9 +626,11 @@ access_code_cast_sink_service_->ValidateDeviceFromSinkId(cast_sink3.id()) .value()); - EXPECT_CALL(*mock_cast_media_sink_service_impl(), - OpenChannels(cast_sinks_ethernet, SinkSource::kAccessCode)) - .Times(1); + // Overall this unit test should call OpenChannel for each cast sink twice. 
+ // This is on init stored devices and then connecting to a new network will + // trigger a call to add every stored cast device back again. + ExpectOpenChannels(cast_sinks_ethernet, 2); + ExpectHasSink(cast_sinks_ethernet, 2); FastForwardUiAndIoTasks(); @@ -632,12 +654,6 @@ ->GetDict() .empty()); - // Connecting to a new network will trigger a call to add every stored cast - // device back again. - EXPECT_CALL(*mock_cast_media_sink_service_impl(), - OpenChannels(cast_sinks_ethernet, SinkSource::kAccessCode)) - .Times(1); - // When the network changes, the sinks on that network should be removed. EXPECT_CALL(*mock_cast_media_sink_service_impl(), DisconnectAndRemoveSink(cast_sink1)); @@ -708,9 +724,11 @@ access_code_cast_sink_service_->ValidateDeviceFromSinkId(cast_sink3.id()) .value()); - EXPECT_CALL(*mock_cast_media_sink_service_impl(), - OpenChannels(cast_sinks_ethernet, SinkSource::kAccessCode)) - .Times(1); + // Overall this unit test should call OpenChannel for each cast sink twice. + // This is on init stored devices and then connecting to a new network will + // trigger a call to add every stored cast device back again. + ExpectOpenChannels(cast_sinks_ethernet, 2); + ExpectHasSink(cast_sinks_ethernet, 2); FastForwardUiAndIoTasks(); @@ -734,12 +752,6 @@ ->GetDict() .empty()); - // Connecting to a new network will trigger a call to add every stored cast - // device back again. - EXPECT_CALL(*mock_cast_media_sink_service_impl(), - OpenChannels(cast_sinks_ethernet, SinkSource::kAccessCode)) - .Times(1); - content::RunAllTasksUntilIdle(); mock_time_task_runner()->FastForwardUntilNoTasksRemain(); @@ -1084,4 +1096,128 @@ ->GetCurrentDelay()); } +TEST_F(AccessCodeCastSinkServiceTest, TestChangeNetworkWithRouteActive) { + // This test ensures that a call to remove a media sink will NOT be made if + // there is currently an active route. 
+ SetDeviceDurationPrefForTest(base::Seconds(10000)); + const MediaSinkInternal cast_sink1 = CreateCastSink(1); + + mock_cast_media_sink_service_impl()->AddSinkForTest(cast_sink1); + access_code_cast_sink_service_->StoreSinkInPrefs(&cast_sink1); + MediaRoute media_route_cast = CreateRouteForTesting(cast_sink1); + std::vector<MediaRoute> route_list = {media_route_cast}; + + EXPECT_CALL(*mock_cast_media_sink_service_impl(), HasSink(cast_sink1.id())); + // Simulate that this cast sink has an open route. + access_code_cast_sink_service_->media_routes_observer_->OnRoutesUpdated( + route_list); + ON_CALL(*router_, GetCurrentRoutes()) + .WillByDefault(Return(std::vector<MediaRoute>{media_route_cast})); + + // Since the route has not changed, no call to remove the sink should have + // been made after the network changed. + EXPECT_CALL(*mock_cast_media_sink_service_impl(), + DisconnectAndRemoveSink(cast_sink1)) + .Times(0); + + FastForwardUiAndIoTasks(); + content::RunAllTasksUntilIdle(); + + fake_network_info_ = fake_wifi_info_; + ChangeConnectionType(network::mojom::ConnectionType::CONNECTION_WIFI); + + content::RunAllTasksUntilIdle(); + mock_time_task_runner()->FastForwardBy(base::Seconds(100)); + + // Simulate that the route has ended. + ON_CALL(*router_, GetCurrentRoutes()) + .WillByDefault(Return(std::vector<MediaRoute>{})); + access_code_cast_sink_service_->media_routes_observer_->OnRoutesUpdated({}); + + // The sink should NOT now be removed from the media router since it was not + // expired. + EXPECT_CALL(*mock_cast_media_sink_service_impl(), + DisconnectAndRemoveSink(cast_sink1)) + .Times(0); + FastForwardUiAndIoTasks(); + content::RunAllTasksUntilIdle(); + + // The sink did not expire in this situation so it should still exist in the + // pref service. 
+ EXPECT_FALSE( + access_code_cast_sink_service_->pref_updater_->GetDeviceAddedTimeDict() + ->GetDict() + .empty()); + EXPECT_FALSE(access_code_cast_sink_service_->pref_updater_->GetDevicesDict() + ->GetDict() + .empty()); +} + +TEST_F(AccessCodeCastSinkServiceTest, + TestChangeNetworkWithRouteActiveExpiration) { + // This test ensures that a call to remove a media sink will NOT be made if + // there is currently an active route, this time when the sink has expired + // before the network has changed. + SetDeviceDurationPrefForTest(base::Seconds(100)); + MediaSinkInternal cast_sink1 = CreateCastSink(1); + + auto cast_data = cast_sink1.cast_data(); + cast_data.discovery_type = CastDiscoveryType::kAccessCodeRememberedDevice; + cast_sink1.set_cast_data(cast_data); + + mock_cast_media_sink_service_impl()->AddSinkForTest(cast_sink1); + access_code_cast_sink_service_->StoreSinkInPrefs(&cast_sink1); + access_code_cast_sink_service_->SetExpirationTimer(&cast_sink1); + MediaRoute media_route_cast = CreateRouteForTesting(cast_sink1); + std::vector<MediaRoute> route_list = {media_route_cast}; + + EXPECT_CALL(*mock_cast_media_sink_service_impl(), HasSink(cast_sink1.id())); + // Simulate that this cast sink has an open route. + access_code_cast_sink_service_->media_routes_observer_->OnRoutesUpdated( + route_list); + ON_CALL(*router_, GetCurrentRoutes()) + .WillByDefault(Return(std::vector<MediaRoute>{media_route_cast})); + + // Since the route has not changed, no call to remove the sink should have + // been made after the network changed. 
+ EXPECT_CALL(*mock_cast_media_sink_service_impl(), + DisconnectAndRemoveSink(cast_sink1)) + .Times(0); + + // Expire the sink + mock_time_task_runner()->FastForwardBy(base::Seconds(300)); + task_environment_.FastForwardBy(base::Seconds(300)); + + FastForwardUiAndIoTasks(); + content::RunAllTasksUntilIdle(); + + fake_network_info_ = fake_wifi_info_; + ChangeConnectionType(network::mojom::ConnectionType::CONNECTION_WIFI); + + content::RunAllTasksUntilIdle(); + mock_time_task_runner()->FastForwardBy(base::Seconds(300)); + + // Simulate that the route has ended. + ON_CALL(*router_, GetCurrentRoutes()) + .WillByDefault(Return(std::vector<MediaRoute>{})); + access_code_cast_sink_service_->media_routes_observer_->OnRoutesUpdated({}); + + // The sink should now be removed from the media router. + EXPECT_CALL(*mock_cast_media_sink_service_impl(), + DisconnectAndRemoveSink(cast_sink1)); + FastForwardUiAndIoTasks(); + content::RunAllTasksUntilIdle(); + FastForwardUiAndIoTasks(); + + // The sink did expire in this situation so it should not exist in the pref + // service. + EXPECT_TRUE( + access_code_cast_sink_service_->pref_updater_->GetDeviceAddedTimeDict() + ->GetDict() + .empty()); + EXPECT_TRUE(access_code_cast_sink_service_->pref_updater_->GetDevicesDict() + ->GetDict() + .empty()); +} + } // namespace media_router
diff --git a/chrome/browser/media/router/discovery/mdns/cast_media_sink_service_impl.h b/chrome/browser/media/router/discovery/mdns/cast_media_sink_service_impl.h index 1e3d7e1..e1682cd 100644 --- a/chrome/browser/media/router/discovery/mdns/cast_media_sink_service_impl.h +++ b/chrome/browser/media/router/discovery/mdns/cast_media_sink_service_impl.h
@@ -122,6 +122,13 @@ // service. virtual void DisconnectAndRemoveSink(const MediaSinkInternal& sink); + // Returns cast socket open parameters. + // Connect / liveness timeout value are dynamically calculated + // based on results of previous connection attempts. + // |sink|: Sink to open cast channel to. + cast_channel::CastSocketOpenParams CreateCastSocketOpenParams( + const MediaSinkInternal& sink); + private: friend class CastMediaSinkServiceImplTest; FRIEND_TEST_ALL_PREFIXES(CastMediaSinkServiceImplTest, @@ -243,13 +250,6 @@ // DiscoveryNetworkMonitor::Observer implementation void OnNetworksChanged(const std::string& network_id) override; - // Returns cast socket open parameters. Parameters are read from Finch. - // Connect / liveness timeout value are dynamically calculated - // based on results of previous connection attempts. - // |sink|: Sink to open cast channel to. - cast_channel::CastSocketOpenParams CreateCastSocketOpenParams( - const MediaSinkInternal& sink); - // Invoked when opening cast channel on IO thread completes. // |cast_sink|: Cast sink created from mDNS service description, DIAL sink, or // access code sink.
diff --git a/chrome/browser/media/webrtc/native_desktop_media_list.cc b/chrome/browser/media/webrtc/native_desktop_media_list.cc index 6d8c9d9..f5e42037 100644 --- a/chrome/browser/media/webrtc/native_desktop_media_list.cc +++ b/chrome/browser/media/webrtc/native_desktop_media_list.cc
@@ -43,6 +43,7 @@ #if BUILDFLAG(IS_WIN) #include <windows.h> +#include "base/strings/string_util_win.h" #include "ui/views/widget/desktop_aura/desktop_window_tree_host_win.h" #endif @@ -101,16 +102,25 @@ } #if BUILDFLAG(IS_WIN) -BOOL CALLBACK TopLevelCurrentProcessHwndCollector(HWND hwnd, LPARAM param) { +// These Collector functions are repeatedly invoked by `::EnumWindows` and they +// add HWNDs to the vector contained in `param`. Return TRUE to continue the +// enumeration or FALSE to end early. +// +// Collects all capturable HWNDs which are owned by the current process. +BOOL CALLBACK CapturableCurrentProcessHwndCollector(HWND hwnd, LPARAM param) { DWORD process_id; ::GetWindowThreadProcessId(hwnd, &process_id); if (process_id != ::GetCurrentProcessId()) return TRUE; + // Skip windows that aren't visible or are minimized. + if (!::IsWindowVisible(hwnd) || ::IsIconic(hwnd)) + return TRUE; + // Skip windows which are not presented in the taskbar, e.g. the "Restore // pages?" window. HWND owner = ::GetWindow(hwnd, GW_OWNER); - LONG exstyle = GetWindowLong(hwnd, GWL_EXSTYLE); + LONG exstyle = ::GetWindowLong(hwnd, GWL_EXSTYLE); if (owner && !(exstyle & WS_EX_APPWINDOW)) return TRUE; @@ -119,6 +129,8 @@ return TRUE; } +// Collects all HWNDs, which are enumerated in z-order, to create a reference +// for sorting. BOOL CALLBACK AllHwndCollector(HWND hwnd, LPARAM param) { auto* hwnds = reinterpret_cast<std::vector<HWND>*>(param); hwnds->push_back(hwnd); @@ -243,9 +255,9 @@ FormatSources(sources, view_dialog_id, type_); #if BUILDFLAG(IS_WIN) - // If |add_current_process_windows_| is set to false, |capturer_| will - // find the windows owned by the current process for us. Otherwise, we must do - // this. + // If |add_current_process_windows_| is set to false, |capturer_| will have + // found the windows owned by the current process for us. Otherwise, we must + // do this. 
if (add_current_process_windows_) { DCHECK_EQ(type_, DesktopMediaList::Type::kWindow); // WebRTC returns the windows in order of highest z-order to lowest, but @@ -337,7 +349,7 @@ std::vector<DesktopMediaListBase::SourceDescription> NativeDesktopMediaList::Worker::GetCurrentProcessWindows() { std::vector<HWND> current_process_windows; - if (!::EnumWindows(TopLevelCurrentProcessHwndCollector, + if (!::EnumWindows(CapturableCurrentProcessHwndCollector, reinterpret_cast<LPARAM>(¤t_process_windows))) { return std::vector<SourceDescription>(); }
diff --git a/chrome/browser/media/webrtc/native_desktop_media_list_unittest.cc b/chrome/browser/media/webrtc/native_desktop_media_list_unittest.cc index 6eeb581..2ef1b1b 100644 --- a/chrome/browser/media/webrtc/native_desktop_media_list_unittest.cc +++ b/chrome/browser/media/webrtc/native_desktop_media_list_unittest.cc
@@ -41,6 +41,7 @@ #if BUILDFLAG(IS_WIN) #include <windows.h> +#include "base/strings/string_util_win.h" #endif using content::DesktopMediaID; @@ -262,8 +263,13 @@ delete; void TearDown() override { - for (size_t i = 0; i < desktop_widgets_.size(); i++) - desktop_widgets_[i].reset(); +#if BUILDFLAG(IS_WIN) + if (window_open_) + DestroyTestWindow(window_info_); +#endif // BUILDFLAG(IS_WIN + + for (auto& desktop_widget : desktop_widgets_) + desktop_widget.reset(); ChromeViewsTestBase::TearDown(); } @@ -339,14 +345,55 @@ window_list_.erase(window_list_.begin() + i); native_aura_id_map_.erase(native_id); } - #endif // defined(USE_AURA) - void AddWindowsAndVerify(bool has_view_dialog) { - window_capturer_ = new FakeWindowCapturer(); + void CreateCapturerAndModel() { + webrtc::DesktopCaptureOptions options = + content::desktop_capture::CreateDesktopCaptureOptions(); + +#if BUILDFLAG(IS_WIN) + // This option should always be false on Windows so we avoid a potential + // deadlock. + EXPECT_FALSE(options.enumerate_current_process_windows()); +#endif // BUILDFLAG(IS_WIN) + + window_capturer_ = new FakeWindowCapturer(options); + + // Only set `add_current_process_windows` if we're using real test windows. + // The tests that use fake windows will have their expectations fail if + // `model_` picks up other windows on the system. 
+ bool add_current_process_windows = false; +#if BUILDFLAG(IS_WIN) + add_current_process_windows = window_open_; +#endif // BUILDFLAG(IS_WIN) model_ = std::make_unique<NativeDesktopMediaList>( DesktopMediaList::Type::kWindow, - base::WrapUnique(window_capturer_.get())); + base::WrapUnique(window_capturer_.get()), add_current_process_windows); + } + + void UpdateModel() { + base::RunLoop run_loop; + base::OnceClosure update_callback = + base::BindLambdaForTesting([&]() { run_loop.Quit(); }); + model_->Update(std::move(update_callback)); + run_loop.Run(); + } + + DesktopMediaList::Source GetSourceFromModel(content::DesktopMediaID::Id id) { + int source_count = model_->GetSourceCount(); + DesktopMediaList::Source source; + for (int i = 0; i < source_count; i++) { + source = model_->GetSource(i); + if (source.id.id == id) { + return source; + } + } + + return DesktopMediaList::Source(); + } + + void AddWindowsAndVerify(bool has_view_dialog) { + CreateCapturerAndModel(); // Set update period to reduce the time it takes to run tests. model_->SetUpdatePeriod(base::Milliseconds(20)); @@ -410,6 +457,13 @@ testing::Mock::VerifyAndClearExpectations(&observer_); } +#if BUILDFLAG(IS_WIN) + void CreateRealWindow() { + window_open_ = true; + window_info_ = CreateTestWindow(); + } +#endif // BUILDFLAG(IS_WIN) + protected: // Must be listed before |model_|, so it's destroyed last. MockObserver observer_; @@ -421,6 +475,11 @@ std::vector<std::unique_ptr<views::Widget>> desktop_widgets_; std::map<DesktopMediaID::Id, DesktopMediaID::Id> native_aura_id_map_; std::unique_ptr<NativeDesktopMediaList> model_; + +#if BUILDFLAG(IS_WIN) + bool window_open_ = false; + WindowInfo window_info_; +#endif // BUILDFLAG(IS_WIN) }; TEST_F(NativeDesktopMediaListTest, Windows) { @@ -614,10 +673,7 @@ // This test verifies that webrtc::DesktopCapturer::CaptureFrame() is not // called when the thumbnail size is empty. 
TEST_F(NativeDesktopMediaListTest, EmptyThumbnail) { - window_capturer_ = new FakeWindowCapturer(); - model_ = std::make_unique<NativeDesktopMediaList>( - DesktopMediaList::Type::kWindow, - base::WrapUnique(window_capturer_.get())); + CreateCapturerAndModel(); model_->SetThumbnailSize(gfx::Size()); // Set update period to reduce the time it takes to run tests. @@ -686,4 +742,43 @@ FROM_HERE, base::BindOnce(&DestroyTestWindow, info)); window_thread.Stop(); } + +TEST_F(NativeDesktopMediaListTest, CollectsCurrentProcessWindows) { + // We need a real window so we can ensure windows owned by the current + // process are picked up by `model_` even if they aren't enumerated by the + // capturer. + CreateRealWindow(); + CreateCapturerAndModel(); + UpdateModel(); + + // Ensure that `model_` is finding and adding the window to it's sources, and + // not getting it from the capturer. + webrtc::DesktopCapturer::SourceList source_list; + EXPECT_TRUE(window_capturer_->GetSourceList(&source_list)); + EXPECT_EQ(source_list.size(), 0ull); + + content::DesktopMediaID::Id window_id = + reinterpret_cast<intptr_t>(window_info_.hwnd); + DesktopMediaList::Source source = GetSourceFromModel(window_id); + EXPECT_EQ(source.id.id, window_id); + EXPECT_STREQ(base::as_wcstr(source.name.c_str()), kWideWindowTitle); +} + +TEST_F(NativeDesktopMediaListTest, MinimizedCurrentProcessWindows) { + CreateRealWindow(); + CreateCapturerAndModel(); + + webrtc::DesktopCapturer::SourceList source_list; + EXPECT_TRUE(window_capturer_->GetSourceList(&source_list)); + EXPECT_EQ(source_list.size(), 0ull); + + // If we minimize the window it should not appear in `model_`s sources. + ::ShowWindow(window_info_.hwnd, SW_MINIMIZE); + UpdateModel(); + DesktopMediaList::Source source = + GetSourceFromModel(reinterpret_cast<intptr_t>(window_info_.hwnd)); + + // We expect the source is not found. + EXPECT_EQ(source.id.id, content::DesktopMediaID::kNullId); +} #endif // BUILDFLAG(IS_WIN)
diff --git a/chrome/browser/ash/notifications/passphrase_textfield.cc b/chrome/browser/notifications/passphrase_textfield.cc similarity index 93% rename from chrome/browser/ash/notifications/passphrase_textfield.cc rename to chrome/browser/notifications/passphrase_textfield.cc index 244e706..567045b1 100644 --- a/chrome/browser/ash/notifications/passphrase_textfield.cc +++ b/chrome/browser/notifications/passphrase_textfield.cc
@@ -2,13 +2,13 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "chrome/browser/ash/notifications/passphrase_textfield.h" +#include "chrome/browser/notifications/passphrase_textfield.h" #include "base/no_destructor.h" #include "base/strings/utf_string_conversions.h" #include "ui/base/metadata/metadata_impl_macros.h" -namespace ash { +namespace chromeos { PassphraseTextfield::PassphraseTextfield() : Textfield(), show_fake_(false), changed_(true) { @@ -65,4 +65,4 @@ ADD_READONLY_PROPERTY_METADATA(bool, Changed) END_METADATA -} // namespace ash +} // namespace chromeos
diff --git a/chrome/browser/ash/notifications/passphrase_textfield.h b/chrome/browser/notifications/passphrase_textfield.h similarity index 81% rename from chrome/browser/ash/notifications/passphrase_textfield.h rename to chrome/browser/notifications/passphrase_textfield.h index 7300feb..6af36291 100644 --- a/chrome/browser/ash/notifications/passphrase_textfield.h +++ b/chrome/browser/notifications/passphrase_textfield.h
@@ -2,15 +2,15 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef CHROME_BROWSER_ASH_NOTIFICATIONS_PASSPHRASE_TEXTFIELD_H_ -#define CHROME_BROWSER_ASH_NOTIFICATIONS_PASSPHRASE_TEXTFIELD_H_ +#ifndef CHROME_BROWSER_NOTIFICATIONS_PASSPHRASE_TEXTFIELD_H_ +#define CHROME_BROWSER_NOTIFICATIONS_PASSPHRASE_TEXTFIELD_H_ #include <string> #include "ui/base/metadata/metadata_header_macros.h" #include "ui/views/controls/textfield/textfield.h" -namespace ash { +namespace chromeos { class PassphraseTextfield : public views::Textfield { public: @@ -40,6 +40,6 @@ bool changed_; }; -} // namespace ash +} // namespace chromeos -#endif // CHROME_BROWSER_ASH_NOTIFICATIONS_PASSPHRASE_TEXTFIELD_H_ +#endif // CHROME_BROWSER_NOTIFICATIONS_PASSPHRASE_TEXTFIELD_H_
diff --git a/chrome/browser/performance_manager/chrome_browser_main_extra_parts_performance_manager.cc b/chrome/browser/performance_manager/chrome_browser_main_extra_parts_performance_manager.cc index 5aa652c..72fd6e1f 100644 --- a/chrome/browser/performance_manager/chrome_browser_main_extra_parts_performance_manager.cc +++ b/chrome/browser/performance_manager/chrome_browser_main_extra_parts_performance_manager.cc
@@ -42,6 +42,10 @@ #endif // BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(ENABLE_EXTENSIONS) +#include "chrome/browser/performance_manager/extension_watcher.h" +#endif + #if !BUILDFLAG(IS_ANDROID) #include "chrome/browser/performance_manager/mechanisms/page_freezer.h" #include "chrome/browser/performance_manager/policies/high_efficiency_mode_policy.h" @@ -194,6 +198,10 @@ std::make_unique<performance_manager::PageLiveStateDecoratorHelper>(); page_load_tracker_decorator_helper_ = std::make_unique<performance_manager::PageLoadTrackerDecoratorHelper>(); +#if BUILDFLAG(ENABLE_EXTENSIONS) + extension_watcher_ = + std::make_unique<performance_manager::ExtensionWatcher>(); +#endif #if !BUILDFLAG(IS_ANDROID) if (base::FeatureList::IsEnabled( @@ -214,6 +222,9 @@ g_browser_process->profile_manager()->RemoveObserver(this); profile_observations_.RemoveAllObservations(); +#if BUILDFLAG(ENABLE_EXTENSIONS) + extension_watcher_.reset(); +#endif page_load_tracker_decorator_helper_.reset(); page_live_state_data_helper_.reset(); page_load_metrics_observer_.reset();
diff --git a/chrome/browser/performance_manager/chrome_browser_main_extra_parts_performance_manager.h b/chrome/browser/performance_manager/chrome_browser_main_extra_parts_performance_manager.h index 80a0efe..08e1ef0 100644 --- a/chrome/browser/performance_manager/chrome_browser_main_extra_parts_performance_manager.h +++ b/chrome/browser/performance_manager/chrome_browser_main_extra_parts_performance_manager.h
@@ -13,6 +13,7 @@ #include "chrome/browser/profiles/profile.h" #include "chrome/browser/profiles/profile_manager_observer.h" #include "chrome/browser/profiles/profile_observer.h" +#include "extensions/buildflags/buildflags.h" class Profile; @@ -29,6 +30,10 @@ class PerformanceManagerFeatureObserverClient; class PerformanceManagerLifetime; +#if BUILDFLAG(ENABLE_EXTENSIONS) +class ExtensionWatcher; +#endif + namespace policies { class HighEfficiencyModePolicyHelper; } @@ -100,6 +105,10 @@ std::unique_ptr<performance_manager::PageLoadTrackerDecoratorHelper> page_load_tracker_decorator_helper_; +#if BUILDFLAG(ENABLE_EXTENSIONS) + std::unique_ptr<performance_manager::ExtensionWatcher> extension_watcher_; +#endif + #if !BUILDFLAG(IS_ANDROID) std::unique_ptr<performance_manager::policies::HighEfficiencyModePolicyHelper> high_efficiency_mode_policy_helper_;
diff --git a/chrome/browser/performance_manager/extension_watcher.cc b/chrome/browser/performance_manager/extension_watcher.cc new file mode 100644 index 0000000..ce62c314 --- /dev/null +++ b/chrome/browser/performance_manager/extension_watcher.cc
@@ -0,0 +1,36 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "chrome/browser/performance_manager/extension_watcher.h" + +#include "chrome/browser/browser_process.h" +#include "components/performance_manager/embedder/performance_manager_registry.h" +#include "extensions/browser/extension_host.h" + +namespace performance_manager { + +ExtensionWatcher::ExtensionWatcher() { + profile_manager_observation_.Observe(g_browser_process->profile_manager()); +} + +ExtensionWatcher::~ExtensionWatcher() = default; + +void ExtensionWatcher::OnProfileAdded(Profile* profile) { + extension_process_manager_observation_.AddObservation( + extensions::ProcessManager::Get(profile)); +} + +void ExtensionWatcher::OnBackgroundHostCreated( + extensions::ExtensionHost* host) { + auto* registry = PerformanceManagerRegistry::GetInstance(); + DCHECK(registry); + registry->SetPageType(host->host_contents(), PageType::kExtension); +} + +void ExtensionWatcher::OnProcessManagerShutdown( + extensions::ProcessManager* manager) { + extension_process_manager_observation_.RemoveObservation(manager); +} + +} // namespace performance_manager
diff --git a/chrome/browser/performance_manager/extension_watcher.h b/chrome/browser/performance_manager/extension_watcher.h new file mode 100644 index 0000000..19a9a2f --- /dev/null +++ b/chrome/browser/performance_manager/extension_watcher.h
@@ -0,0 +1,43 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CHROME_BROWSER_PERFORMANCE_MANAGER_EXTENSION_WATCHER_H_ +#define CHROME_BROWSER_PERFORMANCE_MANAGER_EXTENSION_WATCHER_H_ + +#include "base/scoped_multi_source_observation.h" +#include "base/scoped_observation.h" +#include "chrome/browser/profiles/profile_manager.h" +#include "chrome/browser/profiles/profile_manager_observer.h" +#include "extensions/browser/process_manager.h" +#include "extensions/browser/process_manager_observer.h" + +namespace performance_manager { + +// Sets the `PageType::kExtension` type on `PageNodes` hosting extension +// background pages. +class ExtensionWatcher : public ProfileManagerObserver, + public extensions::ProcessManagerObserver { + public: + ExtensionWatcher(); + ~ExtensionWatcher() override; + + private: + // ProfileManagerObserver: + void OnProfileAdded(Profile* profile) override; + + // extensions::ProcessManagerObserver: + void OnBackgroundHostCreated(extensions::ExtensionHost* host) override; + void OnProcessManagerShutdown(extensions::ProcessManager* manager) override; + + base::ScopedObservation<ProfileManager, ProfileManagerObserver> + profile_manager_observation_{this}; + + base::ScopedMultiSourceObservation<extensions::ProcessManager, + extensions::ProcessManagerObserver> + extension_process_manager_observation_{this}; +}; + +} // namespace performance_manager + +#endif // CHROME_BROWSER_PERFORMANCE_MANAGER_EXTENSION_WATCHER_H_
diff --git a/chrome/browser/performance_manager/mechanisms/page_loader.cc b/chrome/browser/performance_manager/mechanisms/page_loader.cc index 6219c3f..d2a3955 100644 --- a/chrome/browser/performance_manager/mechanisms/page_loader.cc +++ b/chrome/browser/performance_manager/mechanisms/page_loader.cc
@@ -4,7 +4,6 @@ #include "chrome/browser/performance_manager/mechanisms/page_loader.h" -#include "components/performance_manager/public/decorators/tab_properties_decorator.h" #include "components/performance_manager/public/graph/page_node.h" #include "components/performance_manager/public/web_contents_proxy.h" #include "content/public/browser/browser_task_traits.h" @@ -31,7 +30,7 @@ void PageLoader::LoadPageNode(const PageNode* page_node) { DCHECK(page_node); - DCHECK(TabPropertiesDecorator::Data::FromPageNode(page_node)->IsInTabStrip()); + DCHECK_EQ(page_node->GetType(), PageType::kTab); content::GetUIThreadTaskRunner({})->PostTask( FROM_HERE, base::BindOnce(&LoadPageOnUIThread, page_node->GetContentsProxy()));
diff --git a/chrome/browser/performance_manager/page_node_browsertest.cc b/chrome/browser/performance_manager/page_node_browsertest.cc new file mode 100644 index 0000000..7ebf604 --- /dev/null +++ b/chrome/browser/performance_manager/page_node_browsertest.cc
@@ -0,0 +1,64 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/test/bind.h" +#include "chrome/browser/extensions/extension_browsertest.h" +#include "chrome/browser/ui/browser.h" +#include "chrome/browser/ui/tabs/tab_strip_model.h" +#include "components/performance_manager/graph/page_node_impl.h" +#include "components/performance_manager/public/graph/page_node.h" +#include "components/performance_manager/public/performance_manager.h" +#include "content/public/test/browser_test.h" + +namespace performance_manager { + +namespace { + +using PageNodeBrowserTest = extensions::ExtensionBrowserTest; + +void ExpectPageType(base::WeakPtr<PageNode> page_node, PageType expected_type) { + base::RunLoop run_loop; + PerformanceManager::CallOnGraph(FROM_HERE, base::BindLambdaForTesting([&]() { + EXPECT_EQ(page_node->GetType(), + expected_type); + run_loop.Quit(); + })); + run_loop.Run(); +} + +} // namespace + +// Integration test verifying that the correct type is set for a PageNode +// associated with a tab. +IN_PROC_BROWSER_TEST_F(PageNodeBrowserTest, TypeTab) { + EXPECT_EQ(1, browser()->tab_strip_model()->count()); + + base::WeakPtr<PageNode> page_node = + PerformanceManager::GetPrimaryPageNodeForWebContents( + browser()->tab_strip_model()->GetActiveWebContents()); + + ExpectPageType(page_node, PageType::kTab); +} + +// Integration test verifying that the correct type is set for a PageNode +// associated with an extension background page. 
+IN_PROC_BROWSER_TEST_F(PageNodeBrowserTest, TypeExtension) { + ASSERT_TRUE(embedded_test_server()->Start()); + + const extensions::Extension* extension = LoadExtension( + test_data_dir_.AppendASCII("api_test/browser_action/basics")); + ASSERT_TRUE(extension); + extensions::ExtensionHost* host = + extensions::ProcessManager::Get(profile())->GetBackgroundHostForExtension( + extension->id()); + ASSERT_TRUE(host); + ASSERT_TRUE(host->host_contents()); + + base::WeakPtr<PageNode> page_node = + PerformanceManager::GetPrimaryPageNodeForWebContents( + host->host_contents()); + ExpectPageType(page_node, PageType::kExtension); +} + +} // namespace performance_manager
diff --git a/chrome/browser/performance_manager/policies/background_tab_loading_policy.cc b/chrome/browser/performance_manager/policies/background_tab_loading_policy.cc index 8c7f556..a2dacb0 100644 --- a/chrome/browser/performance_manager/policies/background_tab_loading_policy.cc +++ b/chrome/browser/performance_manager/policies/background_tab_loading_policy.cc
@@ -15,7 +15,6 @@ #include "chrome/browser/profiles/profile.h" #include "components/performance_manager/graph/page_node_impl.h" #include "components/performance_manager/public/decorators/site_data_recorder.h" -#include "components/performance_manager/public/decorators/tab_properties_decorator.h" #include "components/performance_manager/public/graph/frame_node.h" #include "components/performance_manager/public/graph/node_data_describer_registry.h" #include "components/performance_manager/public/graph/policies/background_tab_loading_policy.h" @@ -203,8 +202,7 @@ PageNode* page_node = page_node_and_permission.page_node.get(); if (page_node) { DCHECK(!FindPageNodeToLoadData(page_node)); - DCHECK(TabPropertiesDecorator::Data::FromPageNode(page_node) - ->IsInTabStrip()); + DCHECK_EQ(page_node->GetType(), PageType::kTab); page_nodes_to_load_.push_back( std::make_unique<PageNodeToLoadData>(page_node));
diff --git a/chrome/browser/performance_manager/policies/background_tab_loading_policy_unittest.cc b/chrome/browser/performance_manager/policies/background_tab_loading_policy_unittest.cc index 3fb0429..aec4b108 100644 --- a/chrome/browser/performance_manager/policies/background_tab_loading_policy_unittest.cc +++ b/chrome/browser/performance_manager/policies/background_tab_loading_policy_unittest.cc
@@ -13,7 +13,6 @@ #include "chrome/browser/performance_manager/mechanisms/page_loader.h" #include "components/performance_manager/graph/graph_impl.h" #include "components/performance_manager/graph/page_node_impl.h" -#include "components/performance_manager/public/decorators/tab_properties_decorator.h" #include "components/performance_manager/public/persistence/site_data/feature_usage.h" #include "components/performance_manager/public/persistence/site_data/site_data_reader.h" #include "components/performance_manager/test_support/graph_test_harness.h" @@ -164,10 +163,9 @@ to_load.push_back(page_node_and_notification_permission); EXPECT_CALL(*loader(), LoadPageNode(to_load.back().page_node.get())); - // Set |is_tab| property as this is a requirement to pass the PageNode to + // Mark the PageNode as a tab as this is a requirement to pass it to // ScheduleLoadForRestoredTabs(). - TabPropertiesDecorator::SetIsTabForTesting(to_load.back().page_node.get(), - true); + page_nodes.back()->SetType(PageType::kTab); } policy()->ScheduleLoadForRestoredTabs(to_load); @@ -189,10 +187,9 @@ to_load.push_back(page_node_and_notification_permission); EXPECT_CALL(*loader(), LoadPageNode(to_load.back().page_node.get())); - // Set |is_tab| property as this is a requirement to pass the PageNode to + // Mark the PageNode as a tab as this is a requirement to pass it to // ScheduleLoadForRestoredTabs(). - TabPropertiesDecorator::SetIsTabForTesting(to_load.back().page_node.get(), - true); + page_nodes.back()->SetType(PageType::kTab); } policy()->ScheduleLoadForRestoredTabs(to_load); @@ -213,10 +210,9 @@ page_nodes.back().get()->GetWeakPtr(), false); to_load.push_back(page_node_and_notification_permission); - // Set |is_tab| property as this is a requirement to pass the PageNode to + // Mark the PageNode as a tab as this is a requirement to pass it to // ScheduleLoadForRestoredTabs(). 
- TabPropertiesDecorator::SetIsTabForTesting(to_load.back().page_node.get(), - true); + page_nodes.back()->SetType(PageType::kTab); } PageNodeImpl* page_node_impl = page_nodes[0].get(); @@ -254,9 +250,9 @@ page_node_and_notification_permission_to_load_vector{ page_node_and_notification_permission}; - // Set |is_tab| property as this is a requirement to pass the PageNode to + // Mark the PageNode as a tab as this is a requirement to pass it to // ScheduleLoadForRestoredTabs(). - TabPropertiesDecorator::SetIsTabForTesting(page_node.get(), true); + page_node->SetType(PageType::kTab); EXPECT_CALL(*loader(), LoadPageNode(page_node.get())); policy()->ScheduleLoadForRestoredTabs( @@ -461,11 +457,10 @@ to_load.push_back(notification); - for (auto page_node_and_permission : to_load) { - // Set |is_tab| property as this is a requirement to pass the PageNode - // to ScheduleLoadForRestoredTabs(). - TabPropertiesDecorator::SetIsTabForTesting( - page_node_and_permission.page_node.get(), true); + for (auto& page_node : page_nodes) { + // Mark the PageNode as a tab as this is a requirement to pass it to + // ScheduleLoadForRestoredTabs(). + page_node->SetType(PageType::kTab); } // Test that tabs are loaded in the expected order: @@ -505,10 +500,9 @@ page_nodes.back().get()->GetWeakPtr(), false); to_load.push_back(page_node_and_permisssion); - // Set |is_tab| property as this is a requirement to pass the PageNode to + // Mark the PageNode as a tab as this is a requirement to pass it to // ScheduleLoadForRestoredTabs(). - TabPropertiesDecorator::SetIsTabForTesting(to_load.back().page_node.get(), - true); + page_nodes.back()->SetType(PageType::kTab); } // Use 1 loading slot so only one PageNode loads at a time. policy()->SetMaxSimultaneousLoadsForTesting(1);
diff --git a/chrome/browser/performance_manager/policies/high_efficiency_mode_policy.cc b/chrome/browser/performance_manager/policies/high_efficiency_mode_policy.cc index 05bdaf5b..7bfd7ca 100644 --- a/chrome/browser/performance_manager/policies/high_efficiency_mode_policy.cc +++ b/chrome/browser/performance_manager/policies/high_efficiency_mode_policy.cc
@@ -6,7 +6,6 @@ #include "base/containers/contains.h" #include "chrome/browser/performance_manager/policies/page_discarding_helper.h" -#include "components/performance_manager/public/decorators/tab_properties_decorator.h" #include "components/performance_manager/public/features.h" namespace performance_manager::policies { @@ -48,7 +47,7 @@ void HighEfficiencyModePolicy::OnBeforePageNodeRemoved( const PageNode* page_node) { - if (!TabPropertiesDecorator::Data::FromPageNode(page_node)->IsInTabStrip()) { + if (page_node->GetType() != PageType::kTab) { DCHECK(!base::Contains(active_discard_timers_, page_node)); return; } @@ -57,7 +56,7 @@ } void HighEfficiencyModePolicy::OnIsVisibleChanged(const PageNode* page_node) { - if (!TabPropertiesDecorator::Data::FromPageNode(page_node)->IsInTabStrip()) + if (page_node->GetType() != PageType::kTab) return; // If the page is made visible, any existing timers that refer to it should be @@ -91,9 +90,7 @@ if (high_efficiency_mode_enabled_) { DCHECK(active_discard_timers_.empty()); for (const PageNode* page_node : graph_->GetAllPageNodes()) { - if (TabPropertiesDecorator::Data::FromPageNode(page_node) - ->IsInTabStrip() && - !page_node->IsVisible()) { + if (page_node->GetType() == PageType::kTab && !page_node->IsVisible()) { base::TimeDelta time_before_discard = time_before_discard_ - page_node->GetTimeSinceLastVisibilityChange();
diff --git a/chrome/browser/performance_manager/policies/high_efficiency_mode_policy_unittest.cc b/chrome/browser/performance_manager/policies/high_efficiency_mode_policy_unittest.cc index 26e943f7..4ee06fd 100644 --- a/chrome/browser/performance_manager/policies/high_efficiency_mode_policy_unittest.cc +++ b/chrome/browser/performance_manager/policies/high_efficiency_mode_policy_unittest.cc
@@ -9,7 +9,6 @@ #include "base/time/time.h" #include "chrome/browser/performance_manager/policies/high_efficiency_mode_policy_helper.h" #include "chrome/browser/performance_manager/test_support/page_discarding_utils.h" -#include "components/performance_manager/public/decorators/tab_properties_decorator.h" #include "components/performance_manager/public/features.h" #include "components/performance_manager/public/user_tuning/prefs.h" #include "components/prefs/testing_pref_service.h" @@ -29,7 +28,7 @@ policy_ = policy.get(); graph()->PassToGraph(std::move(policy)); - TabPropertiesDecorator::SetIsTabForTesting(page_node(), true); + page_node()->SetType(PageType::kTab); } void TearDown() override { @@ -67,7 +66,7 @@ } TEST_F(HighEfficiencyModeTest, DontDiscardIfPageIsNotATab) { - TabPropertiesDecorator::SetIsTabForTesting(page_node(), false); + page_node()->SetType(PageType::kUnknown); policy()->OnHighEfficiencyModeChanged(true); page_node()->SetIsVisible(true); page_node()->SetIsVisible(false);
diff --git a/chrome/browser/performance_manager/tab_properties_decorator_browsertest.cc b/chrome/browser/performance_manager/tab_properties_decorator_browsertest.cc deleted file mode 100644 index 55bbc2d..0000000 --- a/chrome/browser/performance_manager/tab_properties_decorator_browsertest.cc +++ /dev/null
@@ -1,41 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "components/performance_manager/public/decorators/tab_properties_decorator.h" - -#include "base/test/bind.h" -#include "chrome/browser/ui/browser.h" -#include "chrome/browser/ui/tabs/tab_strip_model.h" -#include "chrome/test/base/in_process_browser_test.h" -#include "components/performance_manager/graph/page_node_impl.h" -#include "components/performance_manager/public/performance_manager.h" -#include "content/public/test/browser_test.h" - -namespace performance_manager { - -using TabPropertiesDecoratorBrowserTest = InProcessBrowserTest; - -// Integration test verifying that when a PageNode is created for a tab, the -// corresponding tab properties is set. -IN_PROC_BROWSER_TEST_F(TabPropertiesDecoratorBrowserTest, SetIsTab) { - EXPECT_EQ(1, browser()->tab_strip_model()->count()); - - // Get PageNode associated with the current tab. - base::WeakPtr<PageNode> page_node = - PerformanceManager::GetPrimaryPageNodeForWebContents( - browser()->tab_strip_model()->GetActiveWebContents()); - - // Get data from the PageNode and verify the tab properties. - base::RunLoop run_loop; - auto call_on_graph_cb = base::BindLambdaForTesting([&]() { - EXPECT_TRUE(page_node); - EXPECT_TRUE(TabPropertiesDecorator::Data::FromPageNode(page_node.get()) - ->IsInTabStrip()); - run_loop.Quit(); - }); - PerformanceManager::CallOnGraph(FROM_HERE, call_on_graph_cb); - run_loop.Run(); -} - -} // namespace performance_manager
diff --git a/chrome/browser/resources/BUILD.gn b/chrome/browser/resources/BUILD.gn index 9ac223a4..13e2db0 100644 --- a/chrome/browser/resources/BUILD.gn +++ b/chrome/browser/resources/BUILD.gn
@@ -135,9 +135,6 @@ "usb_internals:closure_compile", ] } - if (is_win || is_android || is_linux || is_chromeos) { - deps += [ "sandbox_internals:closure_compile" ] - } if (is_chromeos_ash) { deps += [ "chromeos:closure_compile",
diff --git a/chrome/browser/resources/sandbox_internals/BUILD.gn b/chrome/browser/resources/sandbox_internals/BUILD.gn index 7fdeecb..e00a9658 100644 --- a/chrome/browser/resources/sandbox_internals/BUILD.gn +++ b/chrome/browser/resources/sandbox_internals/BUILD.gn
@@ -2,33 +2,34 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -import("//third_party/closure_compiler/compile_js.gni") +import("//tools/grit/preprocess_if_expr.gni") +import("//tools/typescript/ts_library.gni") -js_type_check("closure_compile") { - if (is_win) { - deps = [ ":sandbox_internals_win" ] +assert(is_android || is_linux || is_chromeos || is_win) + +preprocess_folder = "preprocessed" + +if (is_win) { + ts_files = [ "sandbox_internals_win.ts" ] +} else { + ts_files = [ "sandbox_internals.ts" ] +} + +preprocess_if_expr("preprocess") { + in_folder = "." + out_folder = "$target_gen_dir/$preprocess_folder" + in_files = ts_files +} + +ts_library("build_ts") { + root_dir = "$target_gen_dir/$preprocess_folder" + out_dir = "$target_gen_dir/tsc" + in_files = ts_files + + if (is_android) { + definitions = [ "./sandbox_android.d.ts" ] } - if (is_android || is_linux || is_chromeos) { - deps = [ ":sandbox_internals" ] - } -} -js_library("sandbox_internals") { - # Android & Linux both need _externs for type checks as they share a js file. - deps = [ - ":sandbox_android_externs", - "//ui/webui/resources/js:load_time_data.m", - "//ui/webui/resources/js:util.m", - ] -} - -js_library("sandbox_android_externs") { -} - -js_library("sandbox_internals_win") { - deps = [ - "//ui/webui/resources/js:assert.m", - "//ui/webui/resources/js:cr.m", - "//ui/webui/resources/js:util.m", - ] + deps = [ "//ui/webui/resources:library" ] + extra_deps = [ ":preprocess" ] }
diff --git a/chrome/browser/resources/sandbox_internals/sandbox_android.d.ts b/chrome/browser/resources/sandbox_internals/sandbox_android.d.ts new file mode 100644 index 0000000..90ac691 --- /dev/null +++ b/chrome/browser/resources/sandbox_internals/sandbox_android.d.ts
@@ -0,0 +1,21 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/** @fileoverview Definitions for chrome.getAndroidSandboxStatus */ + +declare namespace chrome { + type AndroidSandboxStatus = { + androidBuildId: string, + pid: string, + procStatus: string, + seccompStatus: number, + secontext: string, + uid: string, + }; + + type GetAndroidStatusCallback = (status: AndroidSandboxStatus) => void; + + // This function is only exposed to the Android chrome://sandbox webui. + function getAndroidSandboxStatus(callback: GetAndroidStatusCallback): void; +}
diff --git a/chrome/browser/resources/sandbox_internals/sandbox_android_externs.js b/chrome/browser/resources/sandbox_internals/sandbox_android_externs.js deleted file mode 100644 index c426f215..0000000 --- a/chrome/browser/resources/sandbox_internals/sandbox_android_externs.js +++ /dev/null
@@ -1,21 +0,0 @@ -// Copyright 2019 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -/** - * @typedef {{ - * seccompStatus: number, - * pid: string, - * uid: string, - * secontext: string, - * procStatus: string, - * androidBuildId: string - * }} - */ -let AndroidSandboxStatus; - -/** - * This function is only exposed to the Android chrome://sandbox webui. - * @param {!function(!AndroidSandboxStatus)=} callback - */ -chrome.getAndroidSandboxStatus = function(callback) {};
diff --git a/chrome/browser/resources/sandbox_internals/sandbox_internals.js b/chrome/browser/resources/sandbox_internals/sandbox_internals.ts similarity index 87% rename from chrome/browser/resources/sandbox_internals/sandbox_internals.js rename to chrome/browser/resources/sandbox_internals/sandbox_internals.ts index 26c99c9..f146849 100644 --- a/chrome/browser/resources/sandbox_internals/sandbox_internals.js +++ b/chrome/browser/resources/sandbox_internals/sandbox_internals.ts
@@ -2,32 +2,33 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -import {$} from 'chrome://resources/js/util.m.js'; - // <if expr="is_linux or chromeos_ash or chromeos_lacros"> import './strings.m.js'; + import {loadTimeData} from 'chrome://resources/js/load_time_data.m.js'; // </if> +import {$} from 'chrome://resources/js/util.m.js'; + /** * CSS classes for different statuses. - * @enum {string} */ -const StatusClass = { - GOOD: 'good', - BAD: 'bad', - MEDIUM: 'medium', - INFO: 'info' -}; +enum StatusClass { + GOOD = 'good', + BAD = 'bad', + MEDIUM = 'medium', + INFO = 'info', +} /** * Adds a row to the sandbox status table. - * @param {string} name The name of the status item. - * @param {string} value The status of the item. - * @param {string?} cssClass A CSS class to apply to the row. - * @return {Element} The newly added TR. + * @param name The name of the status item. + * @param value The status of the item. + * @param cssClass A CSS class to apply to the row. + * @return The newly added TR. */ -function addStatusRow(name, value, cssClass) { +function addStatusRow( + name: string, value: string, cssClass: StatusClass|null): HTMLElement { const row = document.createElement('tr'); const nameCol = row.appendChild(document.createElement('td')); @@ -46,21 +47,9 @@ } /** - * Adds a status row that reports either Yes or No. - * @param {string} name The name of the status item. - * @param {boolean} result The status (good/bad) result. - * @return {Element} The newly added TR. - */ -function addGoodBadRow(name, result) { - return addStatusRow( - name, result ? 'Yes' : 'No', result ? StatusClass.GOOD : StatusClass.BAD); -} - -/** * Reports the overall sandbox status evaluation message. - * @param {boolean} result */ -function setEvaluation(result) { +function setEvaluation(result: boolean) { const message = result ? 'You are adequately sandboxed.' 
: 'You are NOT adequately sandboxed.'; $('evaluation').innerText = message; @@ -71,7 +60,7 @@ * Main page handler for Android. */ function androidHandler() { - chrome.getAndroidSandboxStatus((status) => { + chrome.getAndroidSandboxStatus(status => { let isIsolated = false; let isTsync = false; let isChromeSeccomp = false; @@ -86,7 +75,7 @@ const procStatus = status.procStatus.split('\n'); for (const line of procStatus) { if (line.startsWith('Seccomp')) { - let value = line.split(':')[1].trim(); + let value = line.split(':')[1]!.trim(); let cssClass = StatusClass.BAD; if (value === '2') { value = 'Yes - TSYNC (' + line + ')'; @@ -133,6 +122,18 @@ // </if> // <if expr="is_linux or chromeos_ash or chromeos_lacros"> + +/** + * Adds a status row that reports either Yes or No. + * @param name The name of the status item. + * @param result The status (good/bad) result. + * @return The newly added TR. + */ +function addGoodBadRow(name: string, result: boolean): HTMLElement { + return addStatusRow( + name, result ? 'Yes' : 'No', result ? StatusClass.GOOD : StatusClass.BAD); +} + /** * Main page handler for desktop Linux. */
diff --git a/chrome/browser/resources/sandbox_internals/sandbox_internals_win.js b/chrome/browser/resources/sandbox_internals/sandbox_internals_win.ts similarity index 77% rename from chrome/browser/resources/sandbox_internals/sandbox_internals_win.js rename to chrome/browser/resources/sandbox_internals/sandbox_internals_win.ts index 711696ce..db836ac 100644 --- a/chrome/browser/resources/sandbox_internals/sandbox_internals_win.js +++ b/chrome/browser/resources/sandbox_internals/sandbox_internals_win.ts
@@ -2,98 +2,86 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -import {assert, assertNotReached} from 'chrome://resources/js/assert.m.js'; +import {assert} from 'chrome://resources/js/assert_ts.js'; import {sendWithPromise} from 'chrome://resources/js/cr.m.js'; import {$} from 'chrome://resources/js/util.m.js'; -/** - * @typedef {{ - * name: string, - * enabled: boolean, - * }} - */ -let SandboxFeature; +type SandboxFeature = { + name: string, + enabled: boolean, +}; -/** - * @typedef {{ - * processId: number, - * processType: string, - * name: string, - * metricsName: string, - * sandboxType: string - * }} - */ -let BrowserHostProcess; +type BrowserHostProcess = { + processId: number, + processType: string, + name: string, + metricsName: string, + sandboxType: string, +}; -/** - * @typedef {{ - * processId: number - * }} - */ -let RendererHostProcess; +type RendererHostProcess = { + processId: number, +}; /** * This may have additional fields displayed in the JSON output. * See //sandbox/win/src/sandbox_constants.cc for keys in policy. 
- * @typedef {{ - * processId: number, - * lockdownLevel: string, - * desiredIntegrityLevel: string, - * platformMitigations: string, - * componentFilters: string - * }} */ -let PolicyDiagnostic; +type PolicyDiagnostic = { + appContainerCapabilities: string[], + appContainerSid: string, + componentFilters: string, + desiredIntegrityLevel: string, + lockdownLevel: string, + lowboxSid: string, + platformMitigations: string, + processId: number, +}; -/** - * @typedef {{ - * browser: !Array<!BrowserHostProcess>, - * renderer: !Array<!RendererHostProcess>, - * policies: !Array<!PolicyDiagnostic>, - * features: !Array<!SandboxFeature> - * }} - */ -let SandboxDiagnostics; +type SandboxDiagnostics = { + browser: BrowserHostProcess[], + renderer: RendererHostProcess[], + policies: PolicyDiagnostic[], + features: SandboxFeature[], +}; /** * Represents a mitigation field from the PROCESS_CREATION_MITITAGION_POLICY* * series in Winbase.h. */ -class MitigationField { +abstract class MitigationField { + mitigation: string; + value: number; + mask: number; + offset: number; + /** * mask & value must be 0<=x<=255. - * @param {string} mitigation human name of mitigation. - * @param {number} value value to match within mask. - * @param {number} mask applied before matching. - * @param {number} offset within PC section. + * @param mitigation human name of mitigation. + * @param value value to match within mask. + * @param mask applied before matching. + * @param offset within PC section. */ - constructor(mitigation, value, mask, offset) { - /** @type {string} */ + constructor(mitigation: string, value: number, mask: number, offset: number) { this.mitigation = mitigation; - /** @type {number} */ this.value = value; - /** @type {number} */ this.mask = mask; - /** @type {number} */ this.offset = offset; } /** * Each PC field overrides this as they know where their data is. - * @param {Uint8Array} bytes platform mitigations data. 
- * @return {Uint8Array} chunk containing this field or null. + * @param bytes platform mitigations data. + * @return chunk containing this field or null. */ - getFieldData(bytes) { - assertNotReached(); - } + abstract getFieldData(bytes: Uint8Array): Uint8Array|null; /** * Are all the bits of this field set in the mitigations represented by * |bytes|. - * @param {Uint8Array} bytes platform mitigations. - * @return {boolean} + * @param bytes platform mitigations. */ - isFieldSet(bytes) { + isFieldSet(bytes: Uint8Array): boolean { if (bytes.length !== 4 && bytes.length !== 8 && bytes.length !== 16) { throw new Error('Platform mitigations has unexpected size'); } @@ -103,7 +91,7 @@ } const idx = subfield.length - 1 - Math.floor(this.offset / 8); const ibit = this.offset % 8; - return (subfield[idx] & (this.mask << ibit)) === (this.value << ibit); + return (subfield[idx]! & (this.mask << ibit)) === (this.value << ibit); } } @@ -112,10 +100,10 @@ */ class PC0Field extends MitigationField { /** - * @param {Uint8Array} bytes platform mitigations data. - * @return {Uint8Array} chunk containing this field or null. + * @param bytes platform mitigations data. + * @return chunk containing this field or null. */ - getFieldData(bytes) { + getFieldData(bytes: Uint8Array): Uint8Array { if (bytes.length === 4) { // Win32 only 4 bytes of fields. return bytes; @@ -131,8 +119,7 @@ * PROCESS_CREATION_MITIGATION_POLICY_* */ class PC1Field extends MitigationField { - /** @override */ - getFieldData(bytes) { + getFieldData(bytes: Uint8Array) { if (bytes.length === 8) { return bytes; } else if (bytes.length === 16) { @@ -146,8 +133,7 @@ * PROCESS_CREATION_MITIGATION_POLICY2_* */ class PC2Field extends MitigationField { - /** @override */ - getFieldData(bytes) { + getFieldData(bytes: Uint8Array) { if (bytes.length === 8) { return null; } else if (bytes.length === 16) { @@ -162,8 +148,9 @@ representation of PROCESS_CREATION_MITIGATION_POLICY_* entries. 
*/ class DecodeMitigations { + fields: MitigationField[]; + constructor() { - /* @typedef {{Array<MitigationField>}} */ this.fields = [ // Defined in Windows.h from Winbase.h // basic (pc0) mitigations in {win7},{lsb of pc1}. @@ -262,10 +249,10 @@ } /** - * @param {string} str Hex encoded data. - * @return {Uint8Array} bytes Decoded bytes. + * @param str Hex encoded data. + * @return bytes Decoded bytes. */ - parseHexString(str) { + parseHexString(str: string): Uint8Array { assert((str.length % 2 === 0), 'str must have even length'); const bytes = new Uint8Array(str.length / 2); for (let idx = 0; idx < str.length / 2; idx++) { @@ -277,10 +264,10 @@ /** * Return a list of platform mitigation which are set in |mitigations|. * Mitigations will be in the same order as Winbase.h. - * @param {string} mitigations Hex encoded process mitigation flags. - * @return {!Array<string>} Matched mitigation values. + * @param mitigations Hex encoded process mitigation flags. + * @return Matched mitigation values. */ - enabledMitigations(mitigations) { + enabledMitigations(mitigations: string): string[] { const bytes = this.parseHexString(mitigations); const output = []; for (const item of this.fields) { @@ -294,7 +281,7 @@ const DECODE_MITIGATIONS = new DecodeMitigations(); -const WELL_KNOWN_SIDS = { +const WELL_KNOWN_SIDS: {[sid: string]: string} = { 'S-1-15-3-1': 'InternetClient', 'S-1-15-3-2': 'InternetClientServer', 'S-1-15-3-3': 'PrivateNetworkClientServer', @@ -337,21 +324,18 @@ /** * Maps capabilities to well known values. - * @param {string} - * @return {string} */ -function mapCapabilitySid(sid) { +function mapCapabilitySid(sid: string): string { if (WELL_KNOWN_SIDS[sid]) { - return WELL_KNOWN_SIDS[sid]; + return WELL_KNOWN_SIDS[sid]!; } return sid; } /** * Adds a row to the sandbox-status table. 
- * @param {!Array<Node>} args */ -function addRow(args) { +function addRow(args: Node[]) { const row = document.createElement('tr'); for (const td of args) { row.appendChild(td); @@ -361,10 +345,8 @@ /** * Makes a <td> containing arg as textContent. - * @param {string} textContent - * @return {Node} */ -function makeTextEntry(textContent) { +function makeTextEntry(textContent: string): Node { const col = document.createElement('td'); col.textContent = textContent; return col; @@ -372,10 +354,8 @@ /** * Makes a <td> containing formatted component filter flags. - * @param {PolicyDiagnostic} policy - * @return {Node} */ -function makeComponentFilterEntry(policy) { +function makeComponentFilterEntry(policy: PolicyDiagnostic): Node { const fixed = document.createElement('div'); fixed.classList.add('mitigations'); fixed.innerText = policy.componentFilters; @@ -386,11 +366,8 @@ /** * Makes an expandable <td> containing arg as textContent. - * @param {string} mainEntry is always shown - * @param {Object} expandable - * @return {Node} */ -function makeExpandableEntry(mainEntry, expandable) { +function makeExpandableEntry(mainEntry: string, expandable: Expandable): Node { const button = document.createElement('div'); const expand = document.createElement('div'); button.innerText = '\u2795'; // (+) @@ -413,41 +390,66 @@ return col; } +abstract class Expandable { + expanded: boolean = false; + + onClick(col: HTMLElement): boolean { + this.expanded = !this.expanded; + col.innerText = this.getText(); + return this.expanded; + } + + abstract getText(): string; +} + +class MitigationEntryExpandable extends Expandable { + mitigations: string; + + constructor(mitigations: string) { + super(); + this.mitigations = mitigations; + } + + override getText(): string { + if (this.expanded) { + return DECODE_MITIGATIONS.enabledMitigations(this.mitigations).join('\n'); + } else { + return ''; + } + } +} + +class AppContainerEntryExpandable extends Expandable { + caps: string[]; + + 
constructor(caps: string[]) { + super(); + this.caps = caps; + } + + override getText(): string { + if (this.expanded) { + return this.caps.map(mapCapabilitySid).sort().join('\n'); + } else { + return ''; + } + } +} + /** * Adds a mitigations entry that can expand to show friendly names of the * mitigations. - * @param {string} platformMitigations - * @return {Node} - * @suppress {globalThis} */ -function makeMitigationEntry(platformMitigations) { - const expander = { - expanded: false, - mitigations: platformMitigations, - onClick: function(col) { - this.expanded = !this.expanded; - col.innerText = this.getText(); - return this.expanded; - }, - getText: function() { - if (this.expanded) { - return DECODE_MITIGATIONS.enabledMitigations(this.mitigations) - .join('\n'); - } else { - return ''; - } - } - }; +function makeMitigationEntry(platformMitigations: string): Node { + const expander = new MitigationEntryExpandable(platformMitigations); return makeExpandableEntry(platformMitigations, expander); } /** * Formats a lowbox sid or appcontainer configuration (policies can only * have one or the other). - * @param {PolicyDiagnostic} policy - * @return {Node} */ -function makeLowboxAcEntry(policy) { +function makeLowboxAcEntry(policy: PolicyDiagnostic): Node { if (policy.lowboxSid) { // Lowbox token does not have capabilities but should match AC entries. const fixed = document.createElement('div'); @@ -459,22 +461,8 @@ } if (policy.appContainerSid) { // AC has identifying SID plus lockdown capabilities. 
- const expander = { - expanded: false, - caps: policy.appContainerCapabilities, - onClick: function(col) { - this.expanded = !this.expanded; - col.innerText = this.getText(); - return this.expanded; - }, - getText: function() { - if (this.expanded) { - return this.caps.map(mapCapabilitySid).sort().join('\n'); - } else { - return ''; - } - } - }; + const expander = + new AppContainerEntryExpandable(policy.appContainerCapabilities); return makeExpandableEntry(policy.appContainerSid, expander); } return makeTextEntry(''); @@ -482,17 +470,14 @@ /** * Adds policy information for a process to the sandbox-status table. - * @param {number} pid - * @param {string} type - * @param {string} name - * @param {string} sandbox - * @param {PolicyDiagnostic} policy */ -function addRowForProcess(pid, type, name, sandbox, policy) { +function addRowForProcess( + pid: number, type: string, name: string, sandbox: string, + policy: PolicyDiagnostic) { if (policy) { // Text-only items. const entries = [ - pid, type, name, sandbox, policy.lockdownLevel, + String(pid), type, name, sandbox, policy.lockdownLevel, policy.desiredIntegrityLevel ].map(makeTextEntry); entries.push(makeMitigationEntry(policy.platformMitigations)); @@ -500,16 +485,14 @@ entries.push(makeLowboxAcEntry(policy)); addRow(entries); } else { - addRow([pid, type, name, 'Not Sandboxed', '', '', '', '', ''].map( + addRow([String(pid), type, name, 'Not Sandboxed', '', '', '', '', ''].map( makeTextEntry)); } } -/** @param {!SandboxDiagnostics} results */ -function onGetSandboxDiagnostics(results) { +function onGetSandboxDiagnostics(results: SandboxDiagnostics) { // Make it easy to look up policies. 
- /** @type {!Map<number,!PolicyDiagnostic>} */ - const policies = new Map(); + const policies: Map<number, PolicyDiagnostic> = new Map(); for (const policy of results.policies) { policies.set(policy.processId, policy); } @@ -525,13 +508,14 @@ const pid = process.processId; const name = process.name || process.metricsName; addRowForProcess( - pid, process.processType, name, process.sandboxType, policies.get(pid)); + pid, process.processType, name, process.sandboxType, policies.get(pid)! + ); } // Renderer Processes. for (const process of results.renderer) { const pid = process.processId; - addRowForProcess(pid, 'Renderer', '', 'Renderer', policies.get(pid)); + addRowForProcess(pid, 'Renderer', '', 'Renderer', policies.get(pid)!); } // Raw Diagnostics.
diff --git a/chrome/browser/share/share_history.cc b/chrome/browser/share/share_history.cc index 2eda4bf5..ae9d868 100644 --- a/chrome/browser/share/share_history.cc +++ b/chrome/browser/share/share_history.cc
@@ -35,6 +35,8 @@ // do not fold these constants together. const char* const kShareHistoryKey = "share_history"; +constexpr auto kMaxHistoryAge = base::Days(90); + int TodaysDay() { return (base::Time::Now() - base::Time::UnixEpoch()).InDays(); } @@ -200,7 +202,7 @@ init_finished_ = true; post_init_callbacks_.Notify(); - // TODO(ellyjones): Expire entries older than WINDOW days. + Clear(base::Time(), base::Time::Now() - kMaxHistoryAge); } void ShareHistory::FlushToBackingDb() {
diff --git a/chrome/browser/share/share_history_unittest.cc b/chrome/browser/share/share_history_unittest.cc index 97f951d..d915b739 100644 --- a/chrome/browser/share/share_history_unittest.cc +++ b/chrome/browser/share/share_history_unittest.cc
@@ -47,6 +47,17 @@ baz->set_count(1); } + { + // An old entry that will be expired when the history is loaded from the + // backing DB. + auto* long_ago = proto.mutable_day_histories()->Add(); + long_ago->set_day(DaysSinceUnixEpoch() - 365); + + auto* foo = long_ago->mutable_target_histories()->Add(); + foo->mutable_target()->set_component_name(kTarget0Name); + foo->set_count(2); + } + return proto; } @@ -244,4 +255,16 @@ } } +TEST_F(ShareHistoryTest, OldEntriesExpired) { + (*backing_entries())["share_history"] = BuildTestProto(); + Init(); + + auto result = GetFlatShareHistory(); + EXPECT_EQ(result[0].component_name, kTarget0Name); + + // There are 4 entries today, 1 day yesterday, and 2 entries a year ago; the + // latter should be expired on load. + EXPECT_EQ(result[0].count, 5); +} + } // namespace sharing
diff --git a/chrome/browser/ssl/https_first_mode_settings_tracker.cc b/chrome/browser/ssl/https_first_mode_settings_tracker.cc index bc90d229..ba79c93 100644 --- a/chrome/browser/ssl/https_first_mode_settings_tracker.cc +++ b/chrome/browser/ssl/https_first_mode_settings_tracker.cc
@@ -12,6 +12,7 @@ #include "components/keyed_service/content/browser_context_dependency_manager.h" #include "components/keyed_service/content/browser_context_keyed_service_factory.h" #include "components/prefs/pref_service.h" +#include "components/variations/synthetic_trials.h" #include "content/public/browser/browser_context.h" #if BUILDFLAG(IS_CHROMEOS_ASH) @@ -43,7 +44,8 @@ ChromeMetricsServiceAccessor::RegisterSyntheticFieldTrial( kHttpsFirstModeSyntheticFieldTrialName, enabled ? kHttpsFirstModeSyntheticFieldTrialEnabledGroup - : kHttpsFirstModeSyntheticFieldTrialDisabledGroup); + : kHttpsFirstModeSyntheticFieldTrialDisabledGroup, + variations::SyntheticTrialAnnotationMode::kCurrentLog); } HttpsFirstModeService::~HttpsFirstModeService() = default;
diff --git a/chrome/browser/ui/android/omnibox/java/src/org/chromium/chrome/browser/omnibox/suggestions/AutocompleteMediator.java b/chrome/browser/ui/android/omnibox/java/src/org/chromium/chrome/browser/omnibox/suggestions/AutocompleteMediator.java index 7d0b185..4aba9ff 100644 --- a/chrome/browser/ui/android/omnibox/java/src/org/chromium/chrome/browser/omnibox/suggestions/AutocompleteMediator.java +++ b/chrome/browser/ui/android/omnibox/java/src/org/chromium/chrome/browser/omnibox/suggestions/AutocompleteMediator.java
@@ -820,6 +820,12 @@ transition = PageTransition.LINK; } + // Kick off an action to clear focus and dismiss the suggestions list. + // This normally happens when the target site loads and focus is moved to the webcontents. + // On Android T we occasionally observe focus events to be lost, resulting with Suggestions + // list obscuring the view. + mDelegate.clearOmniboxFocus(); + if (suggestion.getType() == OmniboxSuggestionType.CLIPBOARD_IMAGE) { mDelegate.loadUrlWithPostData(url.getSpec(), transition, inputStart, suggestion.getPostContentType(), suggestion.getPostData());
diff --git a/chrome/browser/ui/android/toolbar/BUILD.gn b/chrome/browser/ui/android/toolbar/BUILD.gn index 0c36433..c7293ed 100644 --- a/chrome/browser/ui/android/toolbar/BUILD.gn +++ b/chrome/browser/ui/android/toolbar/BUILD.gn
@@ -317,8 +317,8 @@ "//components/browser_ui/settings/android:java", "//components/browser_ui/widget/android:java", "//components/feature_engagement:feature_engagement_java", - "//components/optimization_guide/proto:optimization_guide_proto_java", "//components/search_engines/android:java", + "//components/segmentation_platform/public/proto:segmentation_platform_proto_java", "//content/public/android:content_full_java", "//content/public/test/android:content_java_test_support", "//third_party/android_deps:guava_android_java",
diff --git a/chrome/browser/ui/android/toolbar/java/src/org/chromium/chrome/browser/toolbar/adaptive/AdaptiveToolbarStatePredictor.java b/chrome/browser/ui/android/toolbar/java/src/org/chromium/chrome/browser/toolbar/adaptive/AdaptiveToolbarStatePredictor.java index e2abebf..db6c38b 100644 --- a/chrome/browser/ui/android/toolbar/java/src/org/chromium/chrome/browser/toolbar/adaptive/AdaptiveToolbarStatePredictor.java +++ b/chrome/browser/ui/android/toolbar/java/src/org/chromium/chrome/browser/toolbar/adaptive/AdaptiveToolbarStatePredictor.java
@@ -14,8 +14,8 @@ import org.chromium.chrome.browser.profiles.Profile; import org.chromium.chrome.browser.segmentation_platform.SegmentationPlatformServiceFactory; import org.chromium.chrome.browser.toolbar.adaptive.AdaptiveToolbarFeatures.AdaptiveToolbarButtonVariant; -import org.chromium.components.optimization_guide.proto.ModelsProto.OptimizationTarget; import org.chromium.components.segmentation_platform.SegmentationPlatformService; +import org.chromium.components.segmentation_platform.proto.SegmentationProto.SegmentId; import org.chromium.ui.permissions.AndroidPermissionDelegate; /** @@ -191,8 +191,7 @@ segmentationPlatformService.getSelectedSegment( ADAPTIVE_TOOLBAR_SEGMENTATION_KEY, result -> { callback.onResult(new Pair<>(result.isReady, - getAdaptiveToolbarButtonVariantFromOptimizationTarget( - result.selectedSegment))); + getAdaptiveToolbarButtonVariantFromSegmentId(result.selectedSegment))); }); } @@ -218,13 +217,13 @@ } /** - * Conversion method between {@link OptimizationTarget} and {@link + * Conversion method between {@link SegmentId} and {@link * AdaptiveToolbarButtonVariant}. */ @VisibleForTesting - static @AdaptiveToolbarButtonVariant int getAdaptiveToolbarButtonVariantFromOptimizationTarget( - OptimizationTarget optimizationTarget) { - switch (optimizationTarget) { + static @AdaptiveToolbarButtonVariant int getAdaptiveToolbarButtonVariantFromSegmentId( + SegmentId segmentId) { + switch (segmentId) { case OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB: return AdaptiveToolbarButtonVariant.NEW_TAB; case OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
diff --git a/chrome/browser/ui/android/toolbar/java/src/org/chromium/chrome/browser/toolbar/adaptive/AdaptiveToolbarStatePredictorTest.java b/chrome/browser/ui/android/toolbar/java/src/org/chromium/chrome/browser/toolbar/adaptive/AdaptiveToolbarStatePredictorTest.java index 85d110b..12cc92e 100644 --- a/chrome/browser/ui/android/toolbar/java/src/org/chromium/chrome/browser/toolbar/adaptive/AdaptiveToolbarStatePredictorTest.java +++ b/chrome/browser/ui/android/toolbar/java/src/org/chromium/chrome/browser/toolbar/adaptive/AdaptiveToolbarStatePredictorTest.java
@@ -31,7 +31,7 @@ import org.chromium.chrome.test.util.browser.Features; import org.chromium.chrome.test.util.browser.Features.DisableFeatures; import org.chromium.chrome.test.util.browser.Features.EnableFeatures; -import org.chromium.components.optimization_guide.proto.ModelsProto.OptimizationTarget; +import org.chromium.components.segmentation_platform.proto.SegmentationProto.SegmentId; import org.chromium.ui.permissions.AndroidPermissionDelegate; import java.util.HashMap; @@ -272,22 +272,22 @@ @Test @SmallTest - public void testOptimizationTargetToAdaptiveToolbarButtonVariantConversion() { + public void testSegmentIdToAdaptiveToolbarButtonVariantConversion() { Assert.assertEquals(AdaptiveToolbarButtonVariant.NEW_TAB, - AdaptiveToolbarStatePredictor.getAdaptiveToolbarButtonVariantFromOptimizationTarget( - OptimizationTarget.OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB)); + AdaptiveToolbarStatePredictor.getAdaptiveToolbarButtonVariantFromSegmentId( + SegmentId.OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB)); Assert.assertEquals(AdaptiveToolbarButtonVariant.SHARE, - AdaptiveToolbarStatePredictor.getAdaptiveToolbarButtonVariantFromOptimizationTarget( - OptimizationTarget.OPTIMIZATION_TARGET_SEGMENTATION_SHARE)); + AdaptiveToolbarStatePredictor.getAdaptiveToolbarButtonVariantFromSegmentId( + SegmentId.OPTIMIZATION_TARGET_SEGMENTATION_SHARE)); Assert.assertEquals(AdaptiveToolbarButtonVariant.VOICE, - AdaptiveToolbarStatePredictor.getAdaptiveToolbarButtonVariantFromOptimizationTarget( - OptimizationTarget.OPTIMIZATION_TARGET_SEGMENTATION_VOICE)); + AdaptiveToolbarStatePredictor.getAdaptiveToolbarButtonVariantFromSegmentId( + SegmentId.OPTIMIZATION_TARGET_SEGMENTATION_VOICE)); Assert.assertEquals(AdaptiveToolbarButtonVariant.UNKNOWN, - AdaptiveToolbarStatePredictor.getAdaptiveToolbarButtonVariantFromOptimizationTarget( - OptimizationTarget.OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); + AdaptiveToolbarStatePredictor.getAdaptiveToolbarButtonVariantFromSegmentId( + 
SegmentId.OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); Assert.assertEquals(AdaptiveToolbarButtonVariant.UNKNOWN, - AdaptiveToolbarStatePredictor.getAdaptiveToolbarButtonVariantFromOptimizationTarget( - OptimizationTarget.OPTIMIZATION_TARGET_UNKNOWN)); + AdaptiveToolbarStatePredictor.getAdaptiveToolbarButtonVariantFromSegmentId( + SegmentId.OPTIMIZATION_TARGET_UNKNOWN)); } private AdaptiveToolbarStatePredictor buildStatePredictor(boolean toolbarSettingsToggleEnabled,
diff --git a/chrome/browser/ui/ash/holding_space/holding_space_ui_browsertest.cc b/chrome/browser/ui/ash/holding_space/holding_space_ui_browsertest.cc index 41c4f919..753fbde 100644 --- a/chrome/browser/ui/ash/holding_space/holding_space_ui_browsertest.cc +++ b/chrome/browser/ui/ash/holding_space/holding_space_ui_browsertest.cc
@@ -20,6 +20,7 @@ #include "ash/public/cpp/holding_space/holding_space_test_api.h" #include "ash/public/cpp/holding_space/mock_holding_space_client.h" #include "ash/public/cpp/holding_space/mock_holding_space_model_observer.h" +#include "ash/style/ash_color_provider.h" #include "ash/test/view_drawn_waiter.h" #include "base/bind.h" #include "base/callback_helpers.h" @@ -2061,11 +2062,13 @@ EXPECT_TRUE(primary_label->GetVisible()); EXPECT_EQ(primary_label->GetText(), target_file_name); + const bool is_dark_mode_state = AshColorProvider::Get()->IsDarkModeEnabled(); // Initially, no bytes have been received so `secondary_label` should display // `0 B` as there is no knowledge of the total number of bytes expected. EXPECT_TRUE(secondary_label->GetVisible()); EXPECT_EQ(secondary_label->GetText(), u"0 B"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleGrey400); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? gfx::kGoogleGrey400 : gfx::kGoogleGrey700); // The accessible name should indicate that the download is in progress. EXPECT_EQ(GetAccessibleName(download_chips.at(0)), @@ -2083,7 +2086,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"Paused, 0 B"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleGrey400); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? gfx::kGoogleGrey400 : gfx::kGoogleGrey700); // The accessible name should indicate that the download is in progress and // that progress is paused. @@ -2101,7 +2105,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"Paused, 1,024 KB"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleGrey400); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? 
gfx::kGoogleGrey400 : gfx::kGoogleGrey700); // The accessible name should indicate that the download is in progress and // that progress is paused. @@ -2119,7 +2124,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"1,024 KB"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleGrey400); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? gfx::kGoogleGrey400 : gfx::kGoogleGrey700); // The accessible name should indicate that the download is in progress. EXPECT_EQ(GetAccessibleName(download_chips.at(0)), @@ -2137,7 +2143,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"1.0/2.0 MB"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleGrey400); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? gfx::kGoogleGrey400 : gfx::kGoogleGrey700); // The accessible name should indicate that the download is in progress. EXPECT_EQ(GetAccessibleName(download_chips.at(0)), @@ -2155,7 +2162,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"Paused, 1.0/2.0 MB"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleGrey400); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? gfx::kGoogleGrey400 : gfx::kGoogleGrey700); // The accessible name should indicate that the download is in progress and // that progress is paused. @@ -2176,7 +2184,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"Paused, 2.0/2.0 MB"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleGrey400); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? 
gfx::kGoogleGrey400 : gfx::kGoogleGrey700); // The accessible name should indicate that the download is in progress and // that progress is paused. @@ -2196,7 +2205,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"Dangerous file"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleRed300); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? gfx::kGoogleRed300 : gfx::kGoogleRed600); // The accessible name should indicate that the download is dangerous. EXPECT_EQ(GetAccessibleName(download_chips.at(0)), @@ -2211,7 +2221,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"Scanning"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleBlue300); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? gfx::kGoogleBlue300 : gfx::kGoogleBlue600); // The accessible name should indicate that the download is being scanning. EXPECT_EQ(GetAccessibleName(download_chips.at(0)), @@ -2229,7 +2240,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"Confirm download"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleYellow300); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? gfx::kGoogleYellow300 : gfx::kGoogleYellow600); // The accessible name should indicate that the download must be confirmed. EXPECT_EQ(GetAccessibleName(download_chips.at(0)), @@ -2248,7 +2260,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"Paused, 2.0/2.0 MB"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleGrey400); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? 
gfx::kGoogleGrey400 : gfx::kGoogleGrey700); // The accessible name should indicate that the download is in progress and // that progress is paused. @@ -2268,7 +2281,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"Dangerous file"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleRed300); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? gfx::kGoogleRed300 : gfx::kGoogleRed600); // The accessible name should indicate that the download is dangerous. EXPECT_EQ(GetAccessibleName(download_chips.at(0)), @@ -2287,7 +2301,8 @@ EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_TRUE(secondary_label->GetVisible()); WaitForText(secondary_label, u"Paused, 2.0/2.0 MB"); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleGrey400); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? gfx::kGoogleGrey400 : gfx::kGoogleGrey700); // The accessible name should indicate that the download is in progress and // that progress is paused. @@ -2301,7 +2316,8 @@ EXPECT_TRUE(primary_label->GetVisible()); EXPECT_EQ(primary_label->GetText(), target_file_name); EXPECT_FALSE(secondary_label->GetVisible()); - EXPECT_EQ(secondary_label->GetEnabledColor(), gfx::kGoogleGrey400); + EXPECT_EQ(secondary_label->GetEnabledColor(), + is_dark_mode_state ? gfx::kGoogleGrey400 : gfx::kGoogleGrey700); // The accessible name should indicate the target file name. EXPECT_EQ(GetAccessibleName(download_chips.at(0)),
diff --git a/chrome/browser/ui/ash/shelf/chrome_shelf_controller_browsertest.cc b/chrome/browser/ui/ash/shelf/chrome_shelf_controller_browsertest.cc index 99ba2dd..84cc069 100644 --- a/chrome/browser/ui/ash/shelf/chrome_shelf_controller_browsertest.cc +++ b/chrome/browser/ui/ash/shelf/chrome_shelf_controller_browsertest.cc
@@ -958,7 +958,8 @@ TestAppWindowIconObserver test_observer(browser()->profile()); int base_shelf_item_count = shelf_model()->item_count(); - ExtensionTestMessageListener ready_listener("ready", true); + ExtensionTestMessageListener ready_listener("ready", + ReplyBehavior::kWillReply); const Extension* extension = LoadAndLaunchPlatformApp("app_icon", "Launched"); ASSERT_TRUE(extension);
diff --git a/chrome/browser/ui/autofill/chrome_autofill_client.cc b/chrome/browser/ui/autofill/chrome_autofill_client.cc index d7616a3a..136917d 100644 --- a/chrome/browser/ui/autofill/chrome_autofill_client.cc +++ b/chrome/browser/ui/autofill/chrome_autofill_client.cc
@@ -18,6 +18,7 @@ #include "chrome/browser/autofill/address_normalizer_factory.h" #include "chrome/browser/autofill/autocomplete_history_manager_factory.h" #include "chrome/browser/autofill/autofill_offer_manager_factory.h" +#include "chrome/browser/autofill/merchant_promo_code_manager_factory.h" #include "chrome/browser/autofill/personal_data_manager_factory.h" #include "chrome/browser/autofill/risk_util.h" #include "chrome/browser/autofill/strike_database_factory.h" @@ -156,6 +157,17 @@ return AutocompleteHistoryManagerFactory::GetForProfile(profile); } +base::WeakPtr<MerchantPromoCodeManager> +ChromeAutofillClient::GetMerchantPromoCodeManager() { + if (!base::FeatureList::IsEnabled( + features::kAutofillFillMerchantPromoCodeFields)) { + return nullptr; + } + Profile* profile = + Profile::FromBrowserContext(web_contents()->GetBrowserContext()); + return MerchantPromoCodeManagerFactory::GetForProfile(profile)->GetWeakPtr(); +} + PrefService* ChromeAutofillClient::GetPrefs() { return const_cast<PrefService*>(base::as_const(*this).GetPrefs()); }
diff --git a/chrome/browser/ui/autofill/chrome_autofill_client.h b/chrome/browser/ui/autofill/chrome_autofill_client.h index ff68e21..1095929 100644 --- a/chrome/browser/ui/autofill/chrome_autofill_client.h +++ b/chrome/browser/ui/autofill/chrome_autofill_client.h
@@ -67,6 +67,8 @@ version_info::Channel GetChannel() const override; PersonalDataManager* GetPersonalDataManager() override; AutocompleteHistoryManager* GetAutocompleteHistoryManager() override; + base::WeakPtr<MerchantPromoCodeManager> GetMerchantPromoCodeManager() + override; PrefService* GetPrefs() override; const PrefService* GetPrefs() const override; syncer::SyncService* GetSyncService() override;
diff --git a/chrome/browser/ui/cocoa/apps/app_shim_menu_controller_mac_browsertest.mm b/chrome/browser/ui/cocoa/apps/app_shim_menu_controller_mac_browsertest.mm index c1b3c42..19cf8e0 100644 --- a/chrome/browser/ui/cocoa/apps/app_shim_menu_controller_mac_browsertest.mm +++ b/chrome/browser/ui/cocoa/apps/app_shim_menu_controller_mac_browsertest.mm
@@ -48,13 +48,13 @@ void SetUpApps(int flags) { if (flags & PACKAGED_1) { - ExtensionTestMessageListener listener_1("Launched", false); + ExtensionTestMessageListener listener_1("Launched"); app_1_ = InstallAndLaunchPlatformApp("minimal_id"); ASSERT_TRUE(listener_1.WaitUntilSatisfied()); } if (flags & PACKAGED_2) { - ExtensionTestMessageListener listener_2("Launched", false); + ExtensionTestMessageListener listener_2("Launched"); app_2_ = InstallAndLaunchPlatformApp("minimal"); ASSERT_TRUE(listener_2.WaitUntilSatisfied()); }
diff --git a/chrome/browser/ui/cocoa/apps/quit_with_apps_controller_mac_interactive_uitest.mm b/chrome/browser/ui/cocoa/apps/quit_with_apps_controller_mac_interactive_uitest.mm index 0a561dd..28590ab 100644 --- a/chrome/browser/ui/cocoa/apps/quit_with_apps_controller_mac_interactive_uitest.mm +++ b/chrome/browser/ui/cocoa/apps/quit_with_apps_controller_mac_interactive_uitest.mm
@@ -72,7 +72,7 @@ QuitWithAppsController::kQuitWithAppsNotificationID)); // Open an app window. - ExtensionTestMessageListener listener("Launched", false); + ExtensionTestMessageListener listener("Launched"); app_ = InstallAndLaunchPlatformApp("minimal_id"); ASSERT_TRUE(listener.WaitUntilSatisfied());
diff --git a/chrome/browser/ui/tab_helpers.cc b/chrome/browser/ui/tab_helpers.cc index fbe2211..47f7747 100644 --- a/chrome/browser/ui/tab_helpers.cc +++ b/chrome/browser/ui/tab_helpers.cc
@@ -128,8 +128,7 @@ #include "components/optimization_guide/core/optimization_guide_features.h" #include "components/page_info/core/features.h" #include "components/password_manager/core/browser/password_manager.h" -#include "components/performance_manager/public/decorators/tab_properties_decorator.h" -#include "components/performance_manager/public/performance_manager.h" +#include "components/performance_manager/embedder/performance_manager_registry.h" #include "components/permissions/features.h" #include "components/permissions/permission_request_manager.h" #include "components/safe_browsing/content/browser/safe_browsing_navigation_observer.h" @@ -373,8 +372,10 @@ } OutOfMemoryReporter::CreateForWebContents(web_contents); chrome::InitializePageLoadMetricsForWebContents(web_contents); - if (performance_manager::PerformanceManager::IsAvailable()) - performance_manager::TabPropertiesDecorator::SetIsTab(web_contents, true); + if (auto* pm_registry = + performance_manager::PerformanceManagerRegistry::GetInstance()) { + pm_registry->SetPageType(web_contents, performance_manager::PageType::kTab); + } permissions::PermissionRequestManager::CreateForWebContents(web_contents); // The PopupBlockerTabHelper has an implicit dependency on // ChromeSubresourceFilterClient being available in its constructor.
diff --git a/chrome/browser/ui/translate/translate_bubble_ui_action_logger.h b/chrome/browser/ui/translate/translate_bubble_ui_action_logger.h index c1f72be..36bb7514 100644 --- a/chrome/browser/ui/translate/translate_bubble_ui_action_logger.h +++ b/chrome/browser/ui/translate/translate_bubble_ui_action_logger.h
@@ -62,10 +62,10 @@ TARGET_LANGUAGE_MENU_CLICKED = 17, // The user activated the translate page action icon. - PAGE_ACTION_ICON_ACTIVATED = 18, + // [DEPRECATED] PAGE_ACTION_ICON_ACTIVATED = 18, // The user deactivated the translate page action icon. - PAGE_ACTION_ICON_DEACTIVATED = 19, + // [DEPRECATED] PAGE_ACTION_ICON_DEACTIVATED = 19, // The translate bubble was shown to the user. BUBBLE_SHOWN = 20,
diff --git a/chrome/browser/ui/views/apps/chrome_native_app_window_views_aura_ash_browsertest.cc b/chrome/browser/ui/views/apps/chrome_native_app_window_views_aura_ash_browsertest.cc index 57c94ef..5859d7db 100644 --- a/chrome/browser/ui/views/apps/chrome_native_app_window_views_aura_ash_browsertest.cc +++ b/chrome/browser/ui/views/apps/chrome_native_app_window_views_aura_ash_browsertest.cc
@@ -86,7 +86,8 @@ std::unique_ptr<ExtensionTestMessageListener> LaunchPlatformAppWithFocusedWindow() { std::unique_ptr<ExtensionTestMessageListener> launched_listener = - std::make_unique<ExtensionTestMessageListener>("Launched", true); + std::make_unique<ExtensionTestMessageListener>( + "Launched", ReplyBehavior::kWillReply); LoadAndLaunchPlatformApp("leave_fullscreen", launched_listener.get()); // We start by making sure the window is actually focused.
diff --git a/chrome/browser/ui/views/commerce/ntp_discount_consent_dialog_view.cc b/chrome/browser/ui/views/commerce/ntp_discount_consent_dialog_view.cc index e844b50..89c637d5 100644 --- a/chrome/browser/ui/views/commerce/ntp_discount_consent_dialog_view.cc +++ b/chrome/browser/ui/views/commerce/ntp_discount_consent_dialog_view.cc
@@ -48,6 +48,12 @@ SetOwnedByWidget(true); // TODO(meiliang@): Set text for the button. SetButtons(ui::DIALOG_BUTTON_CANCEL | ui::DIALOG_BUTTON_OK); + SetButtonLabel( + ui::DIALOG_BUTTON_CANCEL, + l10n_util::GetStringUTF16(IDS_DISCOUNT_CONTEXTUAL_CONSENT_NO_THANKS)); + SetButtonLabel(ui::DIALOG_BUTTON_OK, + l10n_util::GetStringUTF16( + IDS_NATIVE_NTP_CART_DISCOUNT_CONSENT_ACCEPT_BUTTON)); set_fixed_width(views::LayoutProvider::Get()->GetDistanceMetric( DISTANCE_LARGE_MODAL_DIALOG_PREFERRED_WIDTH));
diff --git a/chrome/browser/ui/views/extensions/extension_dialog_bounds_browsertest.cc b/chrome/browser/ui/views/extensions/extension_dialog_bounds_browsertest.cc index 693a9f4..658e8940 100644 --- a/chrome/browser/ui/views/extensions/extension_dialog_bounds_browsertest.cc +++ b/chrome/browser/ui/views/extensions/extension_dialog_bounds_browsertest.cc
@@ -58,7 +58,7 @@ void ShowOpenFileDialog() { browser()->OpenFile(); } void ShowBigExtensionDialog() { - ExtensionTestMessageListener init_listener("ready", false /* will_reply */); + ExtensionTestMessageListener init_listener("ready"); scoped_refptr<const extensions::Extension> extension = LoadExtension(test_data_dir_.AppendASCII("uitest/tab_traversal"));
diff --git a/chrome/browser/ui/views/extensions/extension_dialog_browsertest.cc b/chrome/browser/ui/views/extensions/extension_dialog_browsertest.cc index 453281e..33f1a27 100644 --- a/chrome/browser/ui/views/extensions/extension_dialog_browsertest.cc +++ b/chrome/browser/ui/views/extensions/extension_dialog_browsertest.cc
@@ -16,7 +16,7 @@ using ExtensionDialogTest = extensions::ExtensionBrowserTest; IN_PROC_BROWSER_TEST_F(ExtensionDialogTest, TextInputViaKeyEvent) { - ExtensionTestMessageListener init_listener("ready", /*will_reply=*/false); + ExtensionTestMessageListener init_listener("ready"); scoped_refptr<const extensions::Extension> extension = LoadExtension(test_data_dir_.AppendASCII("uitest/tab_traversal"));
diff --git a/chrome/browser/ui/views/extensions/extension_dialog_interactive_uitest.cc b/chrome/browser/ui/views/extensions/extension_dialog_interactive_uitest.cc index 7653a93..43a89d2 100644 --- a/chrome/browser/ui/views/extensions/extension_dialog_interactive_uitest.cc +++ b/chrome/browser/ui/views/extensions/extension_dialog_interactive_uitest.cc
@@ -34,10 +34,10 @@ #define MAYBE_TabFocusLoop TabFocusLoop #endif IN_PROC_BROWSER_TEST_F(ExtensionDialogUiTest, MAYBE_TabFocusLoop) { - ExtensionTestMessageListener init_listener("ready", false /* will_reply */); - ExtensionTestMessageListener button1_focus_listener("button1-focused", false); - ExtensionTestMessageListener button2_focus_listener("button2-focused", false); - ExtensionTestMessageListener button3_focus_listener("button3-focused", false); + ExtensionTestMessageListener init_listener("ready"); + ExtensionTestMessageListener button1_focus_listener("button1-focused"); + ExtensionTestMessageListener button2_focus_listener("button2-focused"); + ExtensionTestMessageListener button3_focus_listener("button3-focused"); // Load an extension for the test. scoped_refptr<const extensions::Extension> extension =
diff --git a/chrome/browser/ui/views/extensions/extensions_toolbar_container_interactive_uitest.cc b/chrome/browser/ui/views/extensions/extensions_toolbar_container_interactive_uitest.cc index 0d13af2..5236842 100644 --- a/chrome/browser/ui/views/extensions/extensions_toolbar_container_interactive_uitest.cc +++ b/chrome/browser/ui/views/extensions/extensions_toolbar_container_interactive_uitest.cc
@@ -308,8 +308,7 @@ { // Click on Alpha and wait for it to open the popup. - ExtensionTestMessageListener listener("alpha popup opened", - /*will_reply=*/false); + ExtensionTestMessageListener listener("alpha popup opened"); ClickOnAction(alpha_action); EXPECT_TRUE(listener.WaitUntilSatisfied()); } @@ -330,8 +329,7 @@ *process_manager->GetRenderFrameHostsForExtension(alpha->id()).begin(); content::WebContentsDestroyedWatcher popup_destroyed( content::WebContents::FromRenderFrameHost(popup_frame)); - ExtensionTestMessageListener listener("beta popup opened", - /*will_reply=*/false); + ExtensionTestMessageListener listener("beta popup opened"); ClickOnAction(beta_action); EXPECT_TRUE(listener.WaitUntilSatisfied()); popup_destroyed.Wait(); @@ -369,7 +367,7 @@ container->GetViewForId(extension->id()); EXPECT_TRUE(action_view->GetVisible()); - ExtensionTestMessageListener listener("Popup opened", /*will_reply=*/false); + ExtensionTestMessageListener listener("Popup opened"); EXPECT_TRUE(ui_test_utils::SendMouseMoveSync( ui_test_utils::GetCenterInScreenCoordinates(action_view))); EXPECT_TRUE(ui_controls::SendMouseClick(ui_controls::LEFT)); @@ -715,8 +713,7 @@ l10n_util::GetStringUTF16(IDS_EXTENSIONS_HAS_ACCESS_TO_SITE)}, u"\n"); - ExtensionTestMessageListener injection_listener(kInjectionSucceededMessage, - false /* will_reply */); + ExtensionTestMessageListener injection_listener(kInjectionSucceededMessage); injection_listener.set_extension_id(extension()->id()); GURL url = embedded_test_server()->GetURL("example.com", "/title1.html"); @@ -807,8 +804,7 @@ l10n_util::GetStringUTF16(IDS_EXTENSIONS_HAS_ACCESS_TO_SITE)}, u"\n"); - ExtensionTestMessageListener injection_listener(kInjectionSucceededMessage, - false /* will_reply */); + ExtensionTestMessageListener injection_listener(kInjectionSucceededMessage); injection_listener.set_extension_id(extension()->id()); GURL url = embedded_test_server()->GetURL("example.com", "/title1.html");
diff --git a/chrome/browser/ui/views/frame/browser_non_client_frame_view_chromeos_browsertest.cc b/chrome/browser/ui/views/frame/browser_non_client_frame_view_chromeos_browsertest.cc index 8774a0ee..672ee45 100644 --- a/chrome/browser/ui/views/frame/browser_non_client_frame_view_chromeos_browsertest.cc +++ b/chrome/browser/ui/views/frame/browser_non_client_frame_view_chromeos_browsertest.cc
@@ -22,6 +22,7 @@ #include "chrome/browser/ui/views/frame/immersive_mode_controller.h" #include "chrome/browser/ui/views/frame/immersive_mode_tester.h" #include "chrome/test/base/in_process_browser_test.h" +#include "chromeos/constants/chromeos_features.h" #include "chromeos/ui/base/window_properties.h" #include "chromeos/ui/frame/caption_buttons/frame_caption_button_container_view.h" #include "components/keep_alive_registry/keep_alive_types.h" @@ -1205,10 +1206,26 @@ SkColor active_frame_color = window->GetProperty(chromeos::kFrameActiveColorKey); - EXPECT_EQ(active_frame_color, SkColorSetRGB(0xFD, 0xFE, 0xFF)) - << "RGB: " << SkColorGetR(active_frame_color) << ", " - << SkColorGetG(active_frame_color) << ", " - << SkColorGetB(active_frame_color); + + if (!chromeos::features::IsDarkLightModeEnabled()) { + // `kDefaultFrameColor` will only be used when dark/light mode feature is + // not enabled. + EXPECT_EQ(active_frame_color, SkColorSetRGB(0xFD, 0xFE, 0xFF)) + << "RGB: " << SkColorGetR(active_frame_color) << ", " + << SkColorGetG(active_frame_color) << ", " + << SkColorGetB(active_frame_color); + } else { + const bool is_dark_mode_state = + BrowserView::GetBrowserViewForBrowser(browser()) + ->GetNativeTheme() + ->ShouldUseDarkColors(); + EXPECT_EQ(active_frame_color, is_dark_mode_state + ? gfx::kGoogleGrey900 + : SkColorSetRGB(0xFF, 0xFF, 0xFF)) + << "RGB: " << SkColorGetR(active_frame_color) << ", " + << SkColorGetG(active_frame_color) << ", " + << SkColorGetB(active_frame_color); + } } #if BUILDFLAG(IS_CHROMEOS_ASH)
diff --git a/chrome/browser/ui/views/select_file_dialog_extension_browsertest.cc b/chrome/browser/ui/views/select_file_dialog_extension_browsertest.cc index 270258c8..767a28b 100644 --- a/chrome/browser/ui/views/select_file_dialog_extension_browsertest.cc +++ b/chrome/browser/ui/views/select_file_dialog_extension_browsertest.cc
@@ -297,13 +297,12 @@ } // Open the file dialog: Files app will signal that it is loaded via the // "ready" chrome.test.sendMessage(). - const bool will_reply = false; - ExtensionTestMessageListener init_listener("ready", will_reply); + ExtensionTestMessageListener init_listener("ready"); std::unique_ptr<ExtensionTestMessageListener> additional_listener; if (!additional_message.empty()) { - additional_listener = std::make_unique<ExtensionTestMessageListener>( - additional_message, will_reply); + additional_listener = + std::make_unique<ExtensionTestMessageListener>(additional_message); } std::u16string title; @@ -619,7 +618,7 @@ ASSERT_EQ(url, web_contents->GetLastCommittedURL()); // Create a listener for the file dialog's "ready" message. - ExtensionTestMessageListener listener("ready", false); + ExtensionTestMessageListener listener("ready"); // Click the file <input> element to open the file dialog. constexpr auto kButton = blink::WebMouseEvent::Button::kLeft;
diff --git a/chrome/browser/ui/views/send_tab_to_self/send_tab_to_self_device_picker_bubble_view.cc b/chrome/browser/ui/views/send_tab_to_self/send_tab_to_self_device_picker_bubble_view.cc index 7c4b7ee2..0e36cfd 100644 --- a/chrome/browser/ui/views/send_tab_to_self/send_tab_to_self_device_picker_bubble_view.cc +++ b/chrome/browser/ui/views/send_tab_to_self/send_tab_to_self_device_picker_bubble_view.cc
@@ -158,26 +158,20 @@ } void SendTabToSelfDevicePickerBubbleView::CreateHintTextLabel() { - views::View* container = AddChildView(std::make_unique<views::View>()); auto* provider = ChromeLayoutProvider::Get(); - container->SetProperty( - views::kMarginsKey, - gfx::Insets::TLBR(0, 0, - provider->GetDistanceMetric( - views::DISTANCE_UNRELATED_CONTROL_VERTICAL), - 0)); - auto* container_layout = - container->SetLayoutManager(std::make_unique<views::BoxLayout>( - views::BoxLayout::Orientation::kVertical, - gfx::Insets::VH(0, provider->GetDistanceMetric( - views::DISTANCE_BUTTON_HORIZONTAL_PADDING)))); - container_layout->set_cross_axis_alignment( - views::BoxLayout::CrossAxisAlignment::kCenter); - - auto* description = container->AddChildView(std::make_unique<views::Label>( + auto* description = AddChildView(std::make_unique<views::Label>( l10n_util::GetStringUTF16( IDS_TOOLBAR_BUTTON_SEND_TAB_TO_SELF_BUTTON_HINT_TEXT), views::style::CONTEXT_LABEL, views::style::STYLE_SECONDARY)); + description->SetProperty( + views::kMarginsKey, + gfx::Insets::TLBR(0, + provider->GetDistanceMetric( + views::DISTANCE_BUTTON_HORIZONTAL_PADDING), + provider->GetDistanceMetric( + views::DISTANCE_UNRELATED_CONTROL_VERTICAL), + provider->GetDistanceMetric( + views::DISTANCE_BUTTON_HORIZONTAL_PADDING))); description->SetMultiLine(true); description->SetHorizontalAlignment(gfx::ALIGN_LEFT); }
diff --git a/chrome/browser/ui/views/translate/translate_icon_view.cc b/chrome/browser/ui/views/translate/translate_icon_view.cc index a5a03d5..dfe1f96d 100644 --- a/chrome/browser/ui/views/translate/translate_icon_view.cc +++ b/chrome/browser/ui/views/translate/translate_icon_view.cc
@@ -71,12 +71,6 @@ void TranslateIconView::OnExecuting( PageActionIconView::ExecuteSource execute_source) {} -void TranslateIconView::OnPressed(bool activated) { - translate::ReportTranslateBubbleUiAction( - activated ? translate::PAGE_ACTION_ICON_ACTIVATED - : translate::PAGE_ACTION_ICON_DEACTIVATED); -} - const gfx::VectorIcon& TranslateIconView::GetVectorIcon() const { return kTranslateIcon; }
diff --git a/chrome/browser/ui/views/translate/translate_icon_view.h b/chrome/browser/ui/views/translate/translate_icon_view.h index d081afcf..9fcd017 100644 --- a/chrome/browser/ui/views/translate/translate_icon_view.h +++ b/chrome/browser/ui/views/translate/translate_icon_view.h
@@ -29,7 +29,6 @@ protected: // PageActionIconView: void OnExecuting(PageActionIconView::ExecuteSource execute_source) override; - void OnPressed(bool activated) override; const gfx::VectorIcon& GetVectorIcon() const override; std::u16string GetTextForTooltipAndAccessibleName() const override; };
diff --git a/chrome/browser/ui/web_applications/web_app_browser_controller.cc b/chrome/browser/ui/web_applications/web_app_browser_controller.cc index b6c86c9..fbf8c1c 100644 --- a/chrome/browser/ui/web_applications/web_app_browser_controller.cc +++ b/chrome/browser/ui/web_applications/web_app_browser_controller.cc
@@ -42,7 +42,14 @@ #if BUILDFLAG(IS_CHROMEOS_ASH) #include "chrome/browser/ash/apps/apk_web_app_service.h" +#endif +#if BUILDFLAG(IS_CHROMEOS_LACROS) +#include "chromeos/crosapi/mojom/web_app_service.mojom.h" +#include "chromeos/lacros/lacros_service.h" +#endif + +#if BUILDFLAG(IS_CHROMEOS) namespace { constexpr char kRelationship[] = "delegate_permission/common.handle_all_urls"; } @@ -149,7 +156,7 @@ return system_app_; } -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) bool WebAppBrowserController::ShouldShowCustomTabBar() const { if (AppBrowserController::ShouldShowCustomTabBar()) return true; @@ -157,6 +164,18 @@ return is_verified_.value_or(false); } +void WebAppBrowserController::CheckDigitalAssetLinkRelationshipForAndroidApp( + const std::string& package_name, + const std::string& fingerprint) { + // base::Unretained is safe as |asset_link_handler_| is owned by this object + // and will be destroyed if this object is destroyed. + const std::string origin = GetAppStartUrl().DeprecatedGetOriginAsURL().spec(); + asset_link_handler_->CheckDigitalAssetLinkRelationshipForAndroidApp( + origin, kRelationship, fingerprint, package_name, + base::BindOnce(&WebAppBrowserController::OnRelationshipCheckComplete, + base::Unretained(this))); +} + void WebAppBrowserController::OnRelationshipCheckComplete( digital_asset_links::RelationshipCheckResult result) { bool should_show_cct = false; @@ -173,7 +192,19 @@ browser()->window()->UpdateCustomTabBarVisibility(should_show_cct, false /* animate */); } -#endif // BUILDFLAG(IS_CHROMEOS_ASH) +#endif // BUILDFLAG(IS_CHROMEOS) + +#if BUILDFLAG(IS_CHROMEOS_LACROS) +void WebAppBrowserController::OnGetAssociatedAndroidPackage( + crosapi::mojom::WebAppAndroidPackagePtr package) { + if (!package) { + // Web app was not installed from an Android package, nothing to check. 
+ return; + } + CheckDigitalAssetLinkRelationshipForAndroidApp(package->package_name, + package->sha256_fingerprint); +} +#endif // BUILDFLAG(IS_CHROMEOS_LACROS) void WebAppBrowserController::OnWebAppUninstalled( const AppId& uninstalled_app_id) { @@ -413,18 +444,19 @@ void WebAppBrowserController::PerformDigitalAssetLinkVerification( Browser* browser) { -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) asset_link_handler_ = std::make_unique<digital_asset_links::DigitalAssetLinksHandler>( browser->profile()->GetURLLoaderFactory()); is_verified_ = absl::nullopt; +#endif +#if BUILDFLAG(IS_CHROMEOS_ASH) ash::ApkWebAppService* apk_web_app_service = ash::ApkWebAppService::Get(browser->profile()); if (!apk_web_app_service || !apk_web_app_service->IsWebOnlyTwa(app_id())) return; - const std::string origin = GetAppStartUrl().DeprecatedGetOriginAsURL().spec(); const absl::optional<std::string> package_name = apk_web_app_service->GetPackageNameForWebApp(app_id()); const absl::optional<std::string> fingerprint = @@ -434,12 +466,24 @@ DCHECK(package_name.has_value()); DCHECK(fingerprint.has_value()); - // base::Unretained is safe as |asset_link_handler_| is owned by this object - // and will be destroyed if this object is destroyed. 
- asset_link_handler_->CheckDigitalAssetLinkRelationshipForAndroidApp( - origin, kRelationship, fingerprint.value(), package_name.value(), - base::BindOnce(&WebAppBrowserController::OnRelationshipCheckComplete, - base::Unretained(this))); + CheckDigitalAssetLinkRelationshipForAndroidApp(*package_name, *fingerprint); +#endif + +#if BUILDFLAG(IS_CHROMEOS_LACROS) + auto* lacros_service = chromeos::LacrosService::Get(); + if (lacros_service && lacros_service->init_params()->web_apps_enabled && + lacros_service->IsAvailable<crosapi::mojom::WebAppService>() && + lacros_service->GetInterfaceVersion( + crosapi::mojom::WebAppService::Uuid_) >= + int{crosapi::mojom::WebAppService::MethodMinVersions:: + kGetAssociatedAndroidPackageMinVersion}) { + lacros_service->GetRemote<crosapi::mojom::WebAppService>() + ->GetAssociatedAndroidPackage( + app_id(), + base::BindOnce( + &WebAppBrowserController::OnGetAssociatedAndroidPackage, + weak_ptr_factory_.GetWeakPtr())); + } #endif }
diff --git a/chrome/browser/ui/web_applications/web_app_browser_controller.h b/chrome/browser/ui/web_applications/web_app_browser_controller.h index 56900d1..c6f86e1e 100644 --- a/chrome/browser/ui/web_applications/web_app_browser_controller.h +++ b/chrome/browser/ui/web_applications/web_app_browser_controller.h
@@ -24,10 +24,14 @@ #include "third_party/skia/include/core/SkColor.h" #include "ui/base/models/image_model.h" -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) #include "components/digital_asset_links/digital_asset_links_handler.h" // nogncheck #endif +#if BUILDFLAG(IS_CHROMEOS_LACROS) +#include "chromeos/crosapi/mojom/web_app_service.mojom-forward.h" +#endif + class Browser; class SkBitmap; @@ -87,7 +91,7 @@ bool HasReloadButton() const override; const ash::SystemWebAppDelegate* system_app() const override; -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) bool ShouldShowCustomTabBar() const override; #endif // BUILDFLAG(IS_CHROMEOS_ASH) @@ -114,10 +118,17 @@ void OnReadIcon(SkBitmap bitmap); void PerformDigitalAssetLinkVerification(Browser* browser); -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) + void CheckDigitalAssetLinkRelationshipForAndroidApp( + const std::string& package_name, + const std::string& fingerprint); void OnRelationshipCheckComplete( digital_asset_links::RelationshipCheckResult result); -#endif // BUILDFLAG(IS_CHROMEOS_ASH) +#endif // BUILDFLAG(IS_CHROMEOS) + +#if BUILDFLAG(IS_CHROMEOS_LACROS) + void OnGetAssociatedAndroidPackage(crosapi::mojom::WebAppAndroidPackagePtr); +#endif // BUILDFLAG(IS_CHROMEOS_LACROS) // Helper function to return the resolved background color from the manifest // given the current state of dark/light mode. @@ -127,14 +138,14 @@ raw_ptr<const ash::SystemWebAppDelegate> system_app_; mutable absl::optional<ui::ImageModel> app_icon_; -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) // The result of digital asset link verification of the web app. // Only used for web-only TWAs installed through the Play Store. 
absl::optional<bool> is_verified_; std::unique_ptr<digital_asset_links::DigitalAssetLinksHandler> asset_link_handler_; -#endif // BUILDFLAG(IS_CHROMEOS_ASH) +#endif // BUILDFLAG(IS_CHROMEOS) base::ScopedObservation<WebAppInstallManager, WebAppInstallManagerObserver> install_manager_observation_{this};
diff --git a/chrome/browser/ui/webui/discards/graph_dump_impl.h b/chrome/browser/ui/webui/discards/graph_dump_impl.h index 2ef0d9e8..fee0bbef 100644 --- a/chrome/browser/ui/webui/discards/graph_dump_impl.h +++ b/chrome/browser/ui/webui/discards/graph_dump_impl.h
@@ -127,6 +127,8 @@ const performance_manager::FrameNode* previous_embedder, EmbeddingType previous_embedding_type) override; // Ignored. + void OnTypeChanged(const performance_manager::PageNode* page_node) override {} + // Ignored. void OnIsVisibleChanged( const performance_manager::PageNode* page_node) override {} // Ignored.
diff --git a/chrome/browser/ui/webui/extensions/extension_settings_browsertest.cc b/chrome/browser/ui/webui/extensions/extension_settings_browsertest.cc index 6f4118cd..3ef8aa9 100644 --- a/chrome/browser/ui/webui/extensions/extension_settings_browsertest.cc +++ b/chrome/browser/ui/webui/extensions/extension_settings_browsertest.cc
@@ -262,7 +262,7 @@ test_data_dir = test_data_dir.AppendASCII("extensions"); extensions::ChromeTestExtensionLoader loader(browser()->profile()); - ExtensionTestMessageListener listener("ready", false); + ExtensionTestMessageListener listener("ready"); scoped_refptr<const extensions::Extension> extension = loader.LoadExtension( test_data_dir.AppendASCII("activity_log/simple_call")); ASSERT_TRUE(listener.WaitUntilSatisfied());
diff --git a/chrome/build/win32.pgo.txt b/chrome/build/win32.pgo.txt index cfc0dc82..bbb1101 100644 --- a/chrome/build/win32.pgo.txt +++ b/chrome/build/win32.pgo.txt
@@ -1 +1 @@ -chrome-win32-main-1653490600-386d2c3059c805b595be5e6c30689e8bb63026ef.profdata +chrome-win32-main-1653512386-25625488509f0c602fc9fa23bce35f754f412647.profdata
diff --git a/chrome/build/win64.pgo.txt b/chrome/build/win64.pgo.txt index e18af06..24817fc 100644 --- a/chrome/build/win64.pgo.txt +++ b/chrome/build/win64.pgo.txt
@@ -1 +1 @@ -chrome-win64-main-1653490600-6ccdd3081a7f2a1d84f9a7aaefbfadb7e07fe34b.profdata +chrome-win64-main-1653512386-3f39d7d1ec1bd1288d100175a1323fe1074c11e7.profdata
diff --git a/chrome/test/BUILD.gn b/chrome/test/BUILD.gn index e2db799d..efc4c19 100644 --- a/chrome/test/BUILD.gn +++ b/chrome/test/BUILD.gn
@@ -1899,8 +1899,8 @@ "../browser/performance_manager/mechanisms/page_discarder_browsertest.cc", "../browser/performance_manager/observers/page_load_metrics_observer_browsertest.cc", "../browser/performance_manager/page_load_tracker_decorator_browsertest.cc", + "../browser/performance_manager/page_node_browsertest.cc", "../browser/performance_manager/policies/bfcache_policy_browsertest.cc", - "../browser/performance_manager/tab_properties_decorator_browsertest.cc", "../browser/permissions/permission_delegation_browsertest.cc", "../browser/permissions/permission_manager_browsertest.cc", "../browser/permissions/permission_request_manager_browsertest.cc", @@ -2409,6 +2409,7 @@ if (is_chromeos) { deps += [ "//chrome/browser/web_applications/app_service", + "//chromeos/constants", "//chromeos/ui/frame", "//components/account_manager_core:test_support", ]
diff --git a/chrome/test/android/browsertests_apk/android_browsertests_jni_onload.cc b/chrome/test/android/browsertests_apk/android_browsertests_jni_onload.cc index 302c558..7d21e30 100644 --- a/chrome/test/android/browsertests_apk/android_browsertests_jni_onload.cc +++ b/chrome/test/android/browsertests_apk/android_browsertests_jni_onload.cc
@@ -5,11 +5,44 @@ #include <memory> #include "base/android/jni_android.h" +#include "base/android/jni_utils.h" +#include "base/android/library_loader/library_loader_hooks.h" +#include "base/bind.h" +#include "base/command_line.h" +#include "base/no_destructor.h" #include "chrome/app/android/chrome_jni_onload.h" #include "chrome/test/base/chrome_test_launcher.h" +#include "chrome/utility/chrome_content_utility_client.h" #include "content/public/app/content_jni_onload.h" #include "content/public/app/content_main.h" +#include "content/public/common/content_switches.h" #include "content/public/test/nested_message_pump_android.h" +#include "content/public/test/network_service_test_helper.h" +#include "services/network/public/mojom/network_service.mojom.h" + +namespace { +bool NativeInit(base::android::LibraryProcessType) { + static base::NoDestructor<content::NetworkServiceTestHelper> + network_service_test_helper; + + // Setup a working test environment for the network service in case it's used. + // Only create this object in the utility process, so that its members don't + // interfere with other test objects in the browser process. + base::CommandLine* command_line = base::CommandLine::ForCurrentProcess(); + if (command_line->GetSwitchValueASCII(switches::kProcessType) == + switches::kUtilityProcess && + command_line->GetSwitchValueASCII(switches::kUtilitySubType) == + network::mojom::NetworkService::Name_) { + ChromeContentUtilityClient::SetNetworkBinderCreationCallback(base::BindOnce( + [](content::NetworkServiceTestHelper* helper, + service_manager::BinderRegistry* registry) { + helper->RegisterNetworkBinders(registry); + }, + network_service_test_helper.get())); + } + return true; +} +} // namespace // This is called by the VM when the shared library is first loaded. 
JNI_EXPORT jint JNI_OnLoad(JavaVM* vm, void* reserved) { @@ -29,6 +62,7 @@ }); content::SetContentMainDelegate(new ChromeTestChromeMainDelegate()); + base::android::SetNativeInitializationHook(NativeInit); return JNI_VERSION_1_4; }
diff --git a/chrome/test/data/client_hints/partitioned_cookies_embeddee.html b/chrome/test/data/client_hints/partitioned_cookies_embeddee.html deleted file mode 100644 index a9213b9..0000000 --- a/chrome/test/data/client_hints/partitioned_cookies_embeddee.html +++ /dev/null
@@ -1,6 +0,0 @@ -<html> -<link rel="icon" href="data:;base64,="> -<head> -</head> -Empty file used for embedded partitioned cookies Origin Trial testing. -</html>
diff --git a/chrome/test/data/client_hints/partitioned_cookies_same_origin.html b/chrome/test/data/client_hints/partitioned_cookies_same_origin.html deleted file mode 100644 index 9add677..0000000 --- a/chrome/test/data/client_hints/partitioned_cookies_same_origin.html +++ /dev/null
@@ -1,6 +0,0 @@ -<html> -<link rel="icon" href="data:;base64,="> -<head> -</head> -Empty file used for same-origin partitioned cookies Origin Trial testing. -</html>
diff --git a/chrome/test/data/webui/chromeos/personalization_app/keyboard_backlight_element_test.ts b/chrome/test/data/webui/chromeos/personalization_app/keyboard_backlight_element_test.ts index 4d25344..b656190 100644 --- a/chrome/test/data/webui/chromeos/personalization_app/keyboard_backlight_element_test.ts +++ b/chrome/test/data/webui/chromeos/personalization_app/keyboard_backlight_element_test.ts
@@ -5,8 +5,8 @@ import 'chrome://personalization/strings.m.js'; import 'chrome://webui-test/mojo_webui_test_support.js'; -import {KeyboardBacklight, KeyboardBacklightActionName, KeyboardBacklightObserver, SetBacklightColorAction} from 'chrome://personalization/trusted/personalization_app.js'; -import {assertEquals, assertTrue} from 'chrome://webui-test/chai_assert.js'; +import {KeyboardBacklight, KeyboardBacklightActionName, KeyboardBacklightObserver, SetBacklightColorAction, SetWallpaperColorAction} from 'chrome://personalization/trusted/personalization_app.js'; +import {assertDeepEquals, assertEquals, assertTrue} from 'chrome://webui-test/chai_assert.js'; import {baseSetup, initElement, teardownElement} from './personalization_app_test_utils.js'; import {TestKeyboardBacklightProvider} from './test_keyboard_backlight_interface_provider.js'; @@ -100,4 +100,17 @@ SetBacklightColorAction; assertEquals(keyboardBacklightProvider.backlightColor, backlightColor); }); + + test('sets wallpaper color in store on first load', async () => { + personalizationStore.expectAction( + KeyboardBacklightActionName.SET_WALLPAPER_COLOR); + keyboardBacklightElement = initElement(KeyboardBacklight); + await keyboardBacklightProvider.whenCalled('setKeyboardBacklightObserver'); + const wallpaperColor = {value: 0x123456}; + keyboardBacklightProvider.fireOnWallpaperColorChanged(wallpaperColor); + const action = await personalizationStore.waitForAction( + KeyboardBacklightActionName.SET_WALLPAPER_COLOR) as + SetWallpaperColorAction; + assertDeepEquals(wallpaperColor, action.wallpaperColor); + }); });
diff --git a/chrome/test/data/webui/chromeos/personalization_app/test_keyboard_backlight_interface_provider.ts b/chrome/test/data/webui/chromeos/personalization_app/test_keyboard_backlight_interface_provider.ts index c3e7db12..6d66129 100644 --- a/chrome/test/data/webui/chromeos/personalization_app/test_keyboard_backlight_interface_provider.ts +++ b/chrome/test/data/webui/chromeos/personalization_app/test_keyboard_backlight_interface_provider.ts
@@ -3,6 +3,7 @@ // found in the LICENSE file. import {BacklightColor, KeyboardBacklightObserverInterface, KeyboardBacklightObserverRemote, KeyboardBacklightProviderInterface} from 'chrome://personalization/trusted/personalization_app.js'; +import {SkColor} from 'chrome://resources/mojo/skia/public/mojom/skcolor.mojom-webui.js'; import {TestBrowserProxy} from 'chrome://webui-test/test_browser_proxy.js'; export class TestKeyboardBacklightProvider extends @@ -33,4 +34,9 @@ this.keyboardBacklightObserverRemote!.onBacklightColorChanged( backlightColor); } + + fireOnWallpaperColorChanged(wallpaperColor: SkColor) { + this.keyboardBacklightObserverRemote!.onWallpaperColorChanged( + wallpaperColor); + } }
diff --git a/chrome/test/data/webui/chromeos/shimless_rma/fake_shimless_rma_service_test.js b/chrome/test/data/webui/chromeos/shimless_rma/fake_shimless_rma_service_test.js index 1b4ac5f..d092110 100644 --- a/chrome/test/data/webui/chromeos/shimless_rma/fake_shimless_rma_service_test.js +++ b/chrome/test/data/webui/chromeos/shimless_rma/fake_shimless_rma_service_test.js
@@ -32,7 +32,7 @@ const states = [ { state: State.kWelcomeScreen, - canCancel: true, + canExit: true, canGoBack: false, error: RmadErrorCode.kOk }, @@ -41,7 +41,7 @@ return service.getCurrentState().then((state) => { assertEquals(state.state, State.kWelcomeScreen); - assertTrue(state.canCancel); + assertTrue(state.canExit); assertFalse(state.canGoBack); assertEquals(state.error, RmadErrorCode.kOk); });
diff --git a/chrome/test/data/webui/chromeos/shimless_rma/shimless_rma_app_test.js b/chrome/test/data/webui/chromeos/shimless_rma/shimless_rma_app_test.js index 8b93d1f..7cdb874 100644 --- a/chrome/test/data/webui/chromeos/shimless_rma/shimless_rma_app_test.js +++ b/chrome/test/data/webui/chromeos/shimless_rma/shimless_rma_app_test.js
@@ -232,7 +232,7 @@ await initializeShimlessRMAApp( [{ state: State.kSelectComponents, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }], @@ -256,7 +256,7 @@ await initializeShimlessRMAApp( [{ state: State.kSelectComponents, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }], @@ -297,7 +297,7 @@ await initializeShimlessRMAApp( [{ state: State.kSelectComponents, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }], @@ -335,7 +335,7 @@ await initializeShimlessRMAApp( [{ state: State.kSelectComponents, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }], @@ -404,7 +404,7 @@ await initializeShimlessRMAApp( [{ state: State.kWelcomeScreen, - canCancel: true, + canExit: true, canGoBack: true, error: RmadErrorCode.kOk }],
diff --git a/chromecast/media/cma/backend/android/audio_sink_android_audiotrack_impl.cc b/chromecast/media/cma/backend/android/audio_sink_android_audiotrack_impl.cc index ecfbb3a..2bec137e 100644 --- a/chromecast/media/cma/backend/android/audio_sink_android_audiotrack_impl.cc +++ b/chromecast/media/cma/backend/android/audio_sink_android_audiotrack_impl.cc
@@ -315,9 +315,12 @@ int bytes_per_frame = num_channels_ * (use_hw_av_sync_ ? sizeof(int16_t) : sizeof(float)); int64_t fed_frames = pending_data_bytes_already_fed_ / bytes_per_frame; - int64_t timestamp_ns_new = pending_data_->timestamp() + - fed_frames * base::Time::kNanosecondsPerSecond / - input_samples_per_second_; + int64_t timestamp_ns_new = + (pending_data_->timestamp() == INT64_MIN) + ? pending_data_->timestamp() + : pending_data_->timestamp() + fed_frames * + base::Time::kNanosecondsPerSecond / + input_samples_per_second_; int written = Java_AudioSinkAudioTrackImpl_writePcm( base::android::AttachCurrentThread(), j_audio_sink_audiotrack_impl_, left_to_send, timestamp_ns_new);
diff --git a/chromecast/media/cma/backend/android/java/src/org/chromium/chromecast/cma/backend/android/AudioSinkAudioTrackImpl.java b/chromecast/media/cma/backend/android/java/src/org/chromium/chromecast/cma/backend/android/AudioSinkAudioTrackImpl.java index 2529946..c5eb08ab 100644 --- a/chromecast/media/cma/backend/android/java/src/org/chromium/chromecast/cma/backend/android/AudioSinkAudioTrackImpl.java +++ b/chromecast/media/cma/backend/android/java/src/org/chromium/chromecast/cma/backend/android/AudioSinkAudioTrackImpl.java
@@ -477,8 +477,17 @@ long beforeMsecs = SystemClock.elapsedRealtime(); int bytesWritten; if (mUseHwAvSync) { - bytesWritten = mAudioTrack.write( - mPcmBuffer, sizeInBytes, AudioTrack.WRITE_BLOCKING, timestampNs); + // Hw av sync stream uses the timestamp in the audio buffer to do + // synchronization. Therefore we need to skip pushing audio data to + // Android AudioTrack if the timestamp is invalid. The audio buffer + // with no timestamp is usually the silence buffer pushed by cma + // backend, not the audio data pushed by the native application. + if (timestampNs == NO_TIMESTAMP) { + bytesWritten = sizeInBytes; + } else { + bytesWritten = mAudioTrack.write( + mPcmBuffer, sizeInBytes, AudioTrack.WRITE_BLOCKING, timestampNs); + } } else { bytesWritten = mAudioTrack.write(mPcmBuffer, sizeInBytes, AudioTrack.WRITE_BLOCKING); } @@ -581,7 +590,12 @@ if (!haveValidRefPoint()) { // No timestamp available yet, just put dummy values and return. mRenderingDelayBuffer.putLong(0, 0); - mRenderingDelayBuffer.putLong(8, NO_TIMESTAMP); + // Hw av sync stream uses the timestamp in the audio buffer instead + // of the reported rendering delay to do synchronization. Therefore + // it is safe to report zero rendering delay when it is not + // available. + mRenderingDelayBuffer.putLong( + 8, mUseHwAvSync ? convertNsecsToUsecs(System.nanoTime()) : NO_TIMESTAMP); mLastRenderingDelayUsecs = NO_TIMESTAMP; return; }
diff --git a/chromeos/chromeos_strings.grd b/chromeos/chromeos_strings.grd index a986ad6e..375c48a 100644 --- a/chromeos/chromeos_strings.grd +++ b/chromeos/chromeos_strings.grd
@@ -2193,6 +2193,9 @@ <message name="IDS_PERSONALIZATION_APP_AVATAR_TAKE_PHOTO" desc="Label for the button to take a webcam photo"> Take a photo </message> + <message name="IDS_PERSONALIZATION_APP_AVATAR_ARIA_LABEL_WEBCAM_VIDEO" desc="Aria label for the webcam video feed"> + Webcam video feed + </message> <message name="IDS_PERSONALIZATION_APP_AVATAR_TAKE_VIDEO" desc="Label for the button to take a short webcam video"> Create a looping video </message> @@ -2211,6 +2214,10 @@ <message name="IDS_PERSONALIZATION_APP_ARIA_ANNOUNCE_AVATAR_CHANGED" desc="Text read out by the screen reader after the user avatar image changes"> Avatar image changed </message> + <message name="IDS_PERSONALIZATION_APP_AVATAR_ARIA_LABEL_CLOSE_CAMERA" desc="Aria label for the button to close the webcam interface modal"> + Close the camera + </message> + <message name="IDS_PERSONALIZATION_APP_SCREENSAVER_LABEL" desc="Label for the Screensaver page in Personalization app."> Screensaver </message>
diff --git a/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_AVATAR_ARIA_LABEL_CLOSE_CAMERA.png.sha1 b/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_AVATAR_ARIA_LABEL_CLOSE_CAMERA.png.sha1 new file mode 100644 index 0000000..242cc0d --- /dev/null +++ b/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_AVATAR_ARIA_LABEL_CLOSE_CAMERA.png.sha1
@@ -0,0 +1 @@ +0142a6f3122f452092c2a447da3217fd3b16119c \ No newline at end of file
diff --git a/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_AVATAR_ARIA_LABEL_WEBCAM_VIDEO.png.sha1 b/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_AVATAR_ARIA_LABEL_WEBCAM_VIDEO.png.sha1 new file mode 100644 index 0000000..242cc0d --- /dev/null +++ b/chromeos/chromeos_strings_grd/IDS_PERSONALIZATION_APP_AVATAR_ARIA_LABEL_WEBCAM_VIDEO.png.sha1
@@ -0,0 +1 @@ +0142a6f3122f452092c2a447da3217fd3b16119c \ No newline at end of file
diff --git a/chromeos/dbus/fwupd/fwupd_client.cc b/chromeos/dbus/fwupd/fwupd_client.cc index fc9a02a..a1b26e3e 100644 --- a/chromeos/dbus/fwupd/fwupd_client.cc +++ b/chromeos/dbus/fwupd/fwupd_client.cc
@@ -22,6 +22,10 @@ namespace { +// This enum should match the UpdatePriority enum here: +// ash/webui/firmware_update_ui/mojom/firmware_update.mojom +enum UpdatePriority { kLow, kMedium, kHigh, kCritical }; + FwupdClient* g_instance = nullptr; const char kCabFileExtension[] = ".cab"; @@ -277,23 +281,29 @@ << " is missing its description text."; } + // If priority isn't specified we use default of low priority + int priority_value = UpdatePriority::kLow; + if (priority) { + priority_value = priority->GetInt(); + } else { + LOG(WARNING) + << "Device: " << device_id + << " is missing its priority field, using default of low priority."; + } + const bool success = - version && priority && !filepath.empty() && !sha_checksum.empty(); + version && !filepath.empty() && !sha_checksum.empty(); // TODO(michaelcheco): Confirm that this is the expected behavior. if (success) { VLOG(1) << "fwupd: Found update version for device: " << device_id << " with version: " << version->GetString(); updates.emplace_back(version->GetString(), description_value, - priority->GetInt(), filepath, sha_checksum); + priority_value, filepath, sha_checksum); } else { if (!version) { LOG(ERROR) << "Device: " << device_id << " is missing its version field."; } - if (!priority) { - LOG(ERROR) << "Device: " << device_id - << " is missing its priority field."; - } if (!uri) { LOG(ERROR) << "Device: " << device_id << " is missing its URI field."; }
diff --git a/chromeos/dbus/fwupd/fwupd_client_unittest.cc b/chromeos/dbus/fwupd/fwupd_client_unittest.cc index 22f5ee8..f03ad2b 100644 --- a/chromeos/dbus/fwupd/fwupd_client_unittest.cc +++ b/chromeos/dbus/fwupd/fwupd_client_unittest.cc
@@ -223,8 +223,7 @@ CHECK_EQ(expected_description_, (*updates)[0].description); // This value is returned by DBus as a uint32_t and is added to a dictionary // that doesn't support unsigned numbers. So it needs to be casted to int. - CHECK_EQ(static_cast<int>(kFakeUpdatePriorityForTesting), - (*updates)[0].priority); + CHECK_EQ(expected_priority_, (*updates)[0].priority); CHECK_EQ(kFakeUpdateUriForTesting, (*updates)[0].filepath.value()); CHECK_EQ(expected_checksum_, (*updates)[0].checksum); } @@ -241,6 +240,10 @@ expected_description_ = description; } + void SetExpectedPriority(const int priority) { + expected_priority_ = priority; + } + void SetExpectNoUpdates(bool no_updates) { expect_no_updates_ = no_updates; } void CheckPropertyChanged(FwupdProperties* properties) { @@ -315,6 +318,7 @@ std::string expected_checksum_; std::string expected_description_; + int expected_priority_ = kFakeUpdatePriorityForTesting; }; // TODO (swifton): Rewrite this test with an observer when it's available. @@ -429,6 +433,65 @@ base::RunLoop().RunUntilIdle(); } +TEST_F(FwupdClientTest, RequestUpgradesWithoutPriority) { + // The observer will check that the update description is parsed and passed + // correctly. + MockObserver observer; + EXPECT_CALL(observer, OnUpdateListResponse(_, _)) + .Times(1) + .WillRepeatedly(Invoke(this, &FwupdClientTest::CheckUpdates)); + fwupd_client_->AddObserver(&observer); + + EXPECT_CALL(*proxy_, DoCallMethodWithErrorResponse(_, _, _)) + .WillRepeatedly(Invoke(this, &FwupdClientTest::OnMethodCalled)); + + auto response = dbus::Response::CreateEmpty(); + + dbus::MessageWriter response_writer(response.get()); + dbus::MessageWriter response_array_writer(nullptr); + dbus::MessageWriter device_array_writer(nullptr); + dbus::MessageWriter dict_writer(nullptr); + + // The response is an array of arrays of dictionaries. Each dictionary is one + // update description. 
+ response_writer.OpenArray("a{sv}", &response_array_writer); + response_array_writer.OpenArray("{sv}", &device_array_writer); + + device_array_writer.OpenDictEntry(&dict_writer); + dict_writer.AppendString(kDescriptionKey); + dict_writer.AppendVariantOfString(kFakeUpdateDescriptionForTesting); + device_array_writer.CloseContainer(&dict_writer); + SetExpectedDescription(kFakeUpdateDescriptionForTesting); + + device_array_writer.OpenDictEntry(&dict_writer); + dict_writer.AppendString(kVersionKey); + dict_writer.AppendVariantOfString(kFakeUpdateVersionForTesting); + device_array_writer.CloseContainer(&dict_writer); + + device_array_writer.OpenDictEntry(&dict_writer); + dict_writer.AppendString(kUriKey); + dict_writer.AppendVariantOfString(kFakeUpdateUriForTesting); + device_array_writer.CloseContainer(&dict_writer); + + device_array_writer.OpenDictEntry(&dict_writer); + dict_writer.AppendString(kChecksumKey); + dict_writer.AppendVariantOfString(kFakeSha256ForTesting); + device_array_writer.CloseContainer(&dict_writer); + SetExpectedChecksum(kFakeSha256ForTesting); + + response_array_writer.CloseContainer(&device_array_writer); + response_writer.CloseContainer(&response_array_writer); + + AddDbusMethodCallResultSimulation(std::move(response), nullptr); + + // Since priority is not specified, we want to use lowest priority + SetExpectedPriority(0); + + fwupd_client_->RequestUpdates(kFakeDeviceIdForTesting); + + base::RunLoop().RunUntilIdle(); +} + TEST_F(FwupdClientTest, TwoChecksumAvailable) { // The observer will check that the update description is parsed and passed // correctly.
diff --git a/components/BUILD.gn b/components/BUILD.gn index 318d69f0..0c29eeaf 100644 --- a/components/BUILD.gn +++ b/components/BUILD.gn
@@ -274,7 +274,10 @@ deps += [ "//components/nacl/browser:unit_tests" ] } - if (!is_fuchsia) { + if (is_fuchsia) { + deps += [ "//components/fuchsia_legacymetrics:unit_tests" ] + } else { + # TODO(crbug.com/1290514): Enable all relevant tests on Fuchsia too. deps += [ "//components/browser_sync:unit_tests", "//components/send_tab_to_self:unit_tests",
diff --git a/components/autofill/core/browser/autofill_client.cc b/components/autofill/core/browser/autofill_client.cc index d2bd88f..f04b44c 100644 --- a/components/autofill/core/browser/autofill_client.cc +++ b/components/autofill/core/browser/autofill_client.cc
@@ -41,10 +41,15 @@ return version_info::Channel::UNKNOWN; } +base::WeakPtr<MerchantPromoCodeManager> +AutofillClient::GetMerchantPromoCodeManager() { + return nullptr; +} + std::unique_ptr<SingleFieldFormFillRouter> AutofillClient::GetSingleFieldFormFillRouter() { return std::make_unique<SingleFieldFormFillRouter>( - GetAutocompleteHistoryManager()); + GetAutocompleteHistoryManager(), GetMerchantPromoCodeManager()); } AutofillOfferManager* AutofillClient::GetAutofillOfferManager() {
diff --git a/components/autofill/core/browser/autofill_client.h b/components/autofill/core/browser/autofill_client.h index b71fed7..3e8954e1 100644 --- a/components/autofill/core/browser/autofill_client.h +++ b/components/autofill/core/browser/autofill_client.h
@@ -73,6 +73,7 @@ class FormStructure; class LogManager; class MigratableCreditCard; +class MerchantPromoCodeManager; class OtpUnmaskDelegate; enum class OtpUnmaskResult; class PersonalDataManager; @@ -317,9 +318,13 @@ // Gets the PersonalDataManager instance associated with the client. virtual PersonalDataManager* GetPersonalDataManager() = 0; - // Gets the AutocompleteHistoryManager instance associate with the client. + // Gets the AutocompleteHistoryManager instance associated with the client. virtual AutocompleteHistoryManager* GetAutocompleteHistoryManager() = 0; + // Gets the MerchantPromoCodeManager instance associated with the + // client (can be null for unsupported platforms). + virtual base::WeakPtr<MerchantPromoCodeManager> GetMerchantPromoCodeManager(); + // Creates and returns a SingleFieldFormFillRouter using the // AutocompleteHistoryManager instance associated with the client. std::unique_ptr<SingleFieldFormFillRouter> GetSingleFieldFormFillRouter();
diff --git a/components/autofill/core/browser/autofill_external_delegate.cc b/components/autofill/core/browser/autofill_external_delegate.cc index 81214981..5635a170 100644 --- a/components/autofill/core/browser/autofill_external_delegate.cc +++ b/components/autofill/core/browser/autofill_external_delegate.cc
@@ -218,7 +218,8 @@ // Only preview the data if it is a profile or a virtual card. if (frontend_id > 0) { FillAutofillFormData(frontend_id, true); - } else if (frontend_id == POPUP_ITEM_ID_AUTOCOMPLETE_ENTRY) { + } else if (frontend_id == POPUP_ITEM_ID_AUTOCOMPLETE_ENTRY || + frontend_id == POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY) { driver_->RendererShouldPreviewFieldWithValue(query_field_.global_id(), value); } else if (frontend_id == POPUP_ITEM_ID_VIRTUAL_CREDIT_CARD_ENTRY) { @@ -248,8 +249,10 @@ } else if (frontend_id == POPUP_ITEM_ID_DATALIST_ENTRY) { driver_->RendererShouldAcceptDataListSuggestion(query_field_.global_id(), value); - } else if (frontend_id == POPUP_ITEM_ID_AUTOCOMPLETE_ENTRY) { - // User selected an Autocomplete, so we fill directly. + } else if (frontend_id == POPUP_ITEM_ID_AUTOCOMPLETE_ENTRY || + frontend_id == POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY) { + // User selected an Autocomplete or Merchant Promo Code field, so we fill + // directly. driver_->RendererShouldFillFieldWithValue(query_field_.global_id(), value); AutofillMetrics::LogAutocompleteSuggestionAcceptedIndex(position); manager_->OnSingleFieldSuggestionSelected(value, frontend_id);
diff --git a/components/autofill/core/browser/autofill_external_delegate_unittest.cc b/components/autofill/core/browser/autofill_external_delegate_unittest.cc index 01a0c39..7ecee1b 100644 --- a/components/autofill/core/browser/autofill_external_delegate_unittest.cc +++ b/components/autofill/core/browser/autofill_external_delegate_unittest.cc
@@ -93,7 +93,6 @@ (override)); MOCK_METHOD(void, HideAutofillPopup, (PopupHidingReason), (override)); MOCK_METHOD(void, ExecuteCommand, (int), (override)); - // Mock the client query ID check. bool IsQueryIDRelevant(int query_id) { return query_id == kRecentQueryId; } }; @@ -615,6 +614,45 @@ 0); } +// Test that the Autofill delegate still allows previewing and filling +// specifically of the negative ID for POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY. +TEST_F(AutofillExternalDelegateUnitTest, + ExternalDelegateFillsMerchantPromoCodeEntry) { + IssueOnQuery(kRecentQueryId); + + AutofillClient::PopupOpenArgs open_args; + EXPECT_CALL(autofill_client_, ShowAutofillPopup) + .WillOnce(testing::SaveArg<0>(&open_args)); + + // This should call ShowAutofillPopup. + std::vector<Suggestion> suggestions; + suggestions.emplace_back(); + std::u16string promo_code_value = u"PROMOCODE1234"; + suggestions[0].main_text.value = promo_code_value; + suggestions[0].label = u"12.34% off your purchase!"; + suggestions[0].frontend_id = POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY; + external_delegate_->OnSuggestionsReturned( + kRecentQueryId, suggestions, /*autoselect_first_suggestion=*/false); + + // The enums must be cast to ints to prevent compile errors on linux_rel. 
+ EXPECT_THAT(open_args.suggestions, + SuggestionVectorIdsAre(testing::ElementsAre( + static_cast<int>(POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY)))); + + EXPECT_CALL(*autofill_driver_, RendererShouldClearPreviewedForm()).Times(1); + EXPECT_CALL(*autofill_driver_, + RendererShouldPreviewFieldWithValue(field_id_, promo_code_value)); + external_delegate_->DidSelectSuggestion( + promo_code_value, POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY, ""); + EXPECT_CALL(autofill_client_, + HideAutofillPopup(PopupHidingReason::kAcceptSuggestion)); + EXPECT_CALL(*autofill_driver_, + RendererShouldFillFieldWithValue(field_id_, promo_code_value)); + external_delegate_->DidAcceptSuggestion( + promo_code_value, POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY, std::string(), + 0); +} + // Test that the ClearPreview call is only sent if the form was being previewed // (i.e. it isn't autofilling a password). TEST_F(AutofillExternalDelegateUnitTest, ExternalDelegateClearPreviewedForm) {
diff --git a/components/autofill/core/browser/browser_autofill_manager.cc b/components/autofill/core/browser/browser_autofill_manager.cc index 98285ae..e298448 100644 --- a/components/autofill/core/browser/browser_autofill_manager.cc +++ b/components/autofill/core/browser/browser_autofill_manager.cc
@@ -678,7 +678,7 @@ std::unique_ptr<FormStructure> submitted_form = ValidateSubmittedForm(form); if (!submitted_form) { single_field_form_fill_router_->OnWillSubmitForm( - form, client()->IsAutocompleteEnabled()); + form, submitted_form.get(), client()->IsAutocompleteEnabled()); return; } @@ -721,7 +721,8 @@ } } single_field_form_fill_router_->OnWillSubmitForm( - form_for_autocomplete, client()->IsAutocompleteEnabled()); + form_for_autocomplete, submitted_form.get(), + client()->IsAutocompleteEnabled()); if (IsAutofillProfileEnabled()) { address_form_event_logger_->OnWillSubmitForm(sync_state_, *submitted_form);
diff --git a/components/autofill/core/browser/browser_autofill_manager_unittest.cc b/components/autofill/core/browser/browser_autofill_manager_unittest.cc index 96d77e96..9ff171b 100644 --- a/components/autofill/core/browser/browser_autofill_manager_unittest.cc +++ b/components/autofill/core/browser/browser_autofill_manager_unittest.cc
@@ -42,6 +42,7 @@ #include "components/autofill/core/browser/geo/alternative_state_name_map_test_utils.h" #include "components/autofill/core/browser/metrics/form_events/form_events.h" #include "components/autofill/core/browser/mock_autocomplete_history_manager.h" +#include "components/autofill/core/browser/mock_merchant_promo_code_manager.h" #include "components/autofill/core/browser/mock_single_field_form_fill_router.h" #include "components/autofill/core/browser/payments/test_credit_card_save_manager.h" #include "components/autofill/core/browser/payments/test_credit_card_save_strike_database.h" @@ -381,6 +382,10 @@ /*profile_database=*/database_, /*pref_service=*/autofill_client_.GetPrefs(), /*is_off_the_record=*/false); + merchant_promo_code_manager_ = + std::make_unique<NiceMock<MockMerchantPromoCodeManager>>(); + merchant_promo_code_manager_->Init(&personal_data(), + /*is_off_the_record=*/false); autofill_driver_ = std::make_unique<testing::NiceMock<MockAutofillDriver>>(); @@ -405,7 +410,8 @@ auto single_field_form_fill_router = std::make_unique<NiceMock<MockSingleFieldFormFillRouter>>( - autocomplete_history_manager_.get()); + autocomplete_history_manager_.get(), + merchant_promo_code_manager_.get()->GetWeakPtr()); single_field_form_fill_router_ = single_field_form_fill_router.get(); browser_autofill_manager_->set_single_field_form_fill_router_for_test( std::move(single_field_form_fill_router)); @@ -739,7 +745,8 @@ auto single_field_form_fill_router = std::make_unique<NiceMock<MockSingleFieldFormFillRouter>>( - autocomplete_history_manager_.get()); + autocomplete_history_manager_.get(), + merchant_promo_code_manager_.get()->GetWeakPtr()); single_field_form_fill_router_ = single_field_form_fill_router.get(); browser_autofill_manager_->set_single_field_form_fill_router_for_test( std::move(single_field_form_fill_router)); @@ -759,6 +766,7 @@ scoped_refptr<AutofillWebDataService> database_; raw_ptr<MockAutofillDownloadManager> download_manager_; 
std::unique_ptr<MockAutocompleteHistoryManager> autocomplete_history_manager_; + std::unique_ptr<MockMerchantPromoCodeManager> merchant_promo_code_manager_; raw_ptr<MockSingleFieldFormFillRouter> single_field_form_fill_router_; base::test::ScopedFeatureList scoped_feature_list_; raw_ptr<TestStrikeDatabase> strike_database_; @@ -5621,7 +5629,7 @@ FormData form; test::CreateTestAddressFormData(&form); - EXPECT_CALL(*(single_field_form_fill_router_), OnWillSubmitForm(_, true)); + EXPECT_CALL(*(single_field_form_fill_router_), OnWillSubmitForm(_, _, true)); FormSubmitted(form); } @@ -7385,7 +7393,7 @@ TEST_P(BrowserAutofillManagerStructuredProfileTest, DontSaveCvcInAutocompleteHistory) { FormData form_seen_by_ahm; - EXPECT_CALL(*(single_field_form_fill_router_), OnWillSubmitForm(_, true)) + EXPECT_CALL(*(single_field_form_fill_router_), OnWillSubmitForm(_, _, true)) .WillOnce(SaveArg<0>(&form_seen_by_ahm)); FormData form;
diff --git a/components/autofill/core/browser/merchant_promo_code_manager.cc b/components/autofill/core/browser/merchant_promo_code_manager.cc index e948ee48..b81cdb5 100644 --- a/components/autofill/core/browser/merchant_promo_code_manager.cc +++ b/components/autofill/core/browser/merchant_promo_code_manager.cc
@@ -9,7 +9,6 @@ #include "components/autofill/core/browser/data_model/autofill_offer_data.h" #include "components/autofill/core/browser/personal_data_manager.h" #include "components/autofill/core/browser/suggestions_context.h" -#include "components/autofill/core/browser/ui/popup_item_ids.h" namespace autofill { @@ -55,7 +54,9 @@ void MerchantPromoCodeManager::OnSingleFieldSuggestionSelected( const std::u16string& value, - int frontend_id) {} + int frontend_id) { + // TODO(crbug.com/1190334): Add promo code suggestion accepted metrics here. +} void MerchantPromoCodeManager::Init( raw_ptr<PersonalDataManager> personal_data_manager,
diff --git a/components/autofill/core/browser/metrics/autofill_metrics.cc b/components/autofill/core/browser/metrics/autofill_metrics.cc index 9b596db6..ab4e3c8 100644 --- a/components/autofill/core/browser/metrics/autofill_metrics.cc +++ b/components/autofill/core/browser/metrics/autofill_metrics.cc
@@ -1372,33 +1372,11 @@ const base::TimeDelta& duration, AutofillClient::PaymentsRpcResult result, AutofillClient::PaymentsRpcCardType card_type) { - std::string result_suffix; - - switch (result) { - case AutofillClient::PaymentsRpcResult::kSuccess: - result_suffix = "Success"; - break; - case AutofillClient::PaymentsRpcResult::kTryAgainFailure: - case AutofillClient::PaymentsRpcResult::kPermanentFailure: - result_suffix = "Failure"; - break; - case AutofillClient::PaymentsRpcResult::kNetworkError: - result_suffix = "NetworkError"; - break; - case AutofillClient::PaymentsRpcResult::kVcnRetrievalTryAgainFailure: - case AutofillClient::PaymentsRpcResult::kVcnRetrievalPermanentFailure: - result_suffix = "VcnRetrievalFailure"; - break; - case AutofillClient::PaymentsRpcResult::kNone: - NOTREACHED(); - return; - } - base::UmaHistogramLongTimes("Autofill.BetterAuth.CardUnmaskDuration.Fido", duration); base::UmaHistogramLongTimes("Autofill.BetterAuth.CardUnmaskDuration.Fido" + - GetCreditCardTypeSuffix(card_type) + "." + - result_suffix, + GetCreditCardTypeSuffix(card_type) + + PaymentsRpcResultToMetricsSuffix(result), duration); } @@ -3556,4 +3534,30 @@ is_same); } +const std::string PaymentsRpcResultToMetricsSuffix( + AutofillClient::PaymentsRpcResult result) { + std::string result_suffix; + + switch (result) { + case AutofillClient::PaymentsRpcResult::kSuccess: + result_suffix = ".Success"; + break; + case AutofillClient::PaymentsRpcResult::kTryAgainFailure: + case AutofillClient::PaymentsRpcResult::kPermanentFailure: + result_suffix = ".Failure"; + break; + case AutofillClient::PaymentsRpcResult::kNetworkError: + result_suffix = ".NetworkError"; + break; + case AutofillClient::PaymentsRpcResult::kVcnRetrievalTryAgainFailure: + case AutofillClient::PaymentsRpcResult::kVcnRetrievalPermanentFailure: + result_suffix = ".VcnRetrievalFailure"; + break; + case AutofillClient::PaymentsRpcResult::kNone: + NOTREACHED(); + } + + return result_suffix; +} + } // namespace autofill
diff --git a/components/autofill/core/browser/metrics/autofill_metrics.h b/components/autofill/core/browser/metrics/autofill_metrics.h index 9dee313..f0cf0d7 100644 --- a/components/autofill/core/browser/metrics/autofill_metrics.h +++ b/components/autofill/core/browser/metrics/autofill_metrics.h
@@ -2129,5 +2129,8 @@ AutofillMetrics::AutofilledFieldUserEditingStatusMetric metric); #endif +const std::string PaymentsRpcResultToMetricsSuffix( + AutofillClient::PaymentsRpcResult result); + } // namespace autofill #endif // COMPONENTS_AUTOFILL_CORE_BROWSER_METRICS_AUTOFILL_METRICS_H_
diff --git a/components/autofill/core/browser/metrics/payments/virtual_card_enrollment_metrics.cc b/components/autofill/core/browser/metrics/payments/virtual_card_enrollment_metrics.cc index c10b82f..429e4640 100644 --- a/components/autofill/core/browser/metrics/payments/virtual_card_enrollment_metrics.cc +++ b/components/autofill/core/browser/metrics/payments/virtual_card_enrollment_metrics.cc
@@ -9,6 +9,7 @@ #include "base/strings/strcat.h" #include "base/strings/string_util.h" #include "base/time/time.h" +#include "components/autofill/core/browser/metrics/autofill_metrics.h" #include "components/autofill/core/browser/payments/virtual_card_enrollment_flow.h" namespace autofill { @@ -73,6 +74,17 @@ succeeded); } +void LogGetDetailsForEnrollmentRequestLatency( + VirtualCardEnrollmentSource source, + AutofillClient::PaymentsRpcResult result, + base::TimeDelta latency) { + base::UmaHistogramMediumTimes( + "Autofill.VirtualCard.GetDetailsForEnrollment.Latency." + + VirtualCardEnrollmentSourceToMetricSuffix(source) + + PaymentsRpcResultToMetricsSuffix(result), + latency); +} + void LogUpdateVirtualCardEnrollmentRequestAttempt( VirtualCardEnrollmentSource source, VirtualCardEnrollmentRequestType type) {
diff --git a/components/autofill/core/browser/metrics/payments/virtual_card_enrollment_metrics.h b/components/autofill/core/browser/metrics/payments/virtual_card_enrollment_metrics.h index a6dd846..d5d78fa3 100644 --- a/components/autofill/core/browser/metrics/payments/virtual_card_enrollment_metrics.h +++ b/components/autofill/core/browser/metrics/payments/virtual_card_enrollment_metrics.h
@@ -7,6 +7,7 @@ #include <string> +#include "components/autofill/core/browser/autofill_client.h" #include "components/autofill/core/browser/payments/virtual_card_enrollment_flow.h" namespace base { @@ -100,6 +101,10 @@ VirtualCardEnrollmentSource source); void LogGetDetailsForEnrollmentRequestResult(VirtualCardEnrollmentSource source, bool succeeded); +void LogGetDetailsForEnrollmentRequestLatency( + VirtualCardEnrollmentSource source, + AutofillClient::PaymentsRpcResult result, + base::TimeDelta latency); // UpdateVirtualCardEnrollmentRequest related metrics. Attempts and results // should be 1:1 mapping.
diff --git a/components/autofill/core/browser/mock_single_field_form_fill_router.cc b/components/autofill/core/browser/mock_single_field_form_fill_router.cc index 286c6d9c7..ff94d51 100644 --- a/components/autofill/core/browser/mock_single_field_form_fill_router.cc +++ b/components/autofill/core/browser/mock_single_field_form_fill_router.cc
@@ -6,8 +6,10 @@ namespace autofill { MockSingleFieldFormFillRouter::MockSingleFieldFormFillRouter( - AutocompleteHistoryManager* autocomplete_history_manager) - : SingleFieldFormFillRouter(autocomplete_history_manager) {} + AutocompleteHistoryManager* autocomplete_history_manager, + base::WeakPtr<MerchantPromoCodeManager> merchant_promo_code_manager) + : SingleFieldFormFillRouter(autocomplete_history_manager, + merchant_promo_code_manager) {} MockSingleFieldFormFillRouter::~MockSingleFieldFormFillRouter() = default;
diff --git a/components/autofill/core/browser/mock_single_field_form_fill_router.h b/components/autofill/core/browser/mock_single_field_form_fill_router.h index 62ac968..30b8f3df 100644 --- a/components/autofill/core/browser/mock_single_field_form_fill_router.h +++ b/components/autofill/core/browser/mock_single_field_form_fill_router.h
@@ -14,10 +14,17 @@ class MockSingleFieldFormFillRouter : public SingleFieldFormFillRouter { public: explicit MockSingleFieldFormFillRouter( - AutocompleteHistoryManager* autocomplete_history_manager); + AutocompleteHistoryManager* autocomplete_history_manager, + base::WeakPtr<MerchantPromoCodeManager> merchant_promo_code_manager); ~MockSingleFieldFormFillRouter() override; MOCK_METHOD(void, + OnWillSubmitForm, + (const FormData& form, + raw_ptr<const FormStructure> form_structure, + bool is_autocomplete_enabled), + (override)); + MOCK_METHOD(void, OnGetSingleFieldSuggestions, (int query_id, bool is_autocomplete_enabled, @@ -29,8 +36,9 @@ const SuggestionsContext& context), (override)); MOCK_METHOD(void, - OnWillSubmitForm, - (const FormData& form, bool is_autocomplete_enabled), + OnWillSubmitFormWithFields, + (const std::vector<FormFieldData>& fields, + bool is_autocomplete_enabled), (override)); MOCK_METHOD(void, CancelPendingQueries,
diff --git a/components/autofill/core/browser/payments/virtual_card_enrollment_manager.cc b/components/autofill/core/browser/payments/virtual_card_enrollment_manager.cc index cc660376..da83272 100644 --- a/components/autofill/core/browser/payments/virtual_card_enrollment_manager.cc +++ b/components/autofill/core/browser/payments/virtual_card_enrollment_manager.cc
@@ -351,6 +351,9 @@ state_.virtual_card_enrollment_fields.credit_card.instrument_id(); request_details.source = state_.virtual_card_enrollment_fields.virtual_card_enrollment_source; + + get_details_for_enrollment_request_sent_timestamp_ = AutofillClock::Now(); + payments_client_->GetVirtualCardEnrollmentDetails( request_details, base::BindOnce( @@ -366,6 +369,15 @@ response) { enroll_response_details_received_ = true; + if (get_details_for_enrollment_request_sent_timestamp_.has_value()) { + LogGetDetailsForEnrollmentRequestLatency( + state_.virtual_card_enrollment_fields.virtual_card_enrollment_source, + result, + AutofillClock::Now() - + get_details_for_enrollment_request_sent_timestamp_.value()); + get_details_for_enrollment_request_sent_timestamp_.reset(); + } + LogGetDetailsForEnrollmentRequestResult( state_.virtual_card_enrollment_fields.virtual_card_enrollment_source, /*succeeded=*/result == AutofillClient::PaymentsRpcResult::kSuccess);
diff --git a/components/autofill/core/browser/payments/virtual_card_enrollment_manager.h b/components/autofill/core/browser/payments/virtual_card_enrollment_manager.h index ea350f5..0e533bc8 100644 --- a/components/autofill/core/browser/payments/virtual_card_enrollment_manager.h +++ b/components/autofill/core/browser/payments/virtual_card_enrollment_manager.h
@@ -323,6 +323,9 @@ // metric. |save_card_bubble_accepted_timestamp_| will then be reset. absl::optional<base::Time> save_card_bubble_accepted_timestamp_; + // The timestamp when a GetDetailsForEnrollment request is sent. + absl::optional<base::Time> get_details_for_enrollment_request_sent_timestamp_; + base::WeakPtrFactory<VirtualCardEnrollmentManager> weak_ptr_factory_{this}; };
diff --git a/components/autofill/core/browser/payments/virtual_card_enrollment_manager_unittest.cc b/components/autofill/core/browser/payments/virtual_card_enrollment_manager_unittest.cc index 4647847..005cf37 100644 --- a/components/autofill/core/browser/payments/virtual_card_enrollment_manager_unittest.cc +++ b/components/autofill/core/browser/payments/virtual_card_enrollment_manager_unittest.cc
@@ -14,6 +14,7 @@ #include "components/autofill/core/browser/autofill_test_utils.h" #include "components/autofill/core/browser/data_model/credit_card.h" #include "components/autofill/core/browser/data_model/credit_card_art_image.h" +#include "components/autofill/core/browser/metrics/autofill_metrics.h" #include "components/autofill/core/browser/metrics/payments/virtual_card_enrollment_metrics.h" #include "components/autofill/core/browser/payments/payments_requests/update_virtual_card_enrollment_request.h" #include "components/autofill/core/browser/payments/payments_util.h" @@ -253,6 +254,7 @@ TEST_F(VirtualCardEnrollmentManagerTest, OnDidGetDetailsForEnrollResponse) { base::HistogramTester histogram_tester; + TestAutofillClock test_autofill_clock(AutofillClock::Now()); const TestLegalMessageLine google_legal_message = TestLegalMessageLine("google_test_legal_message"); const TestLegalMessageLine issuer_legal_message = @@ -269,6 +271,9 @@ #else for (bool make_image_present : {true, false}) { #endif // BUILDFLAG(IS_IOS) + virtual_card_enrollment_manager_ + ->get_details_for_enrollment_request_sent_timestamp_ = + AutofillClock::Now(); payments::PaymentsClient::GetDetailsForEnrollmentResponseDetails response = std::move(SetUpOnDidGetDetailsForEnrollResponse( google_legal_message, issuer_legal_message, make_image_present)); @@ -285,6 +290,8 @@ network_image); } + test_autofill_clock.Advance(base::Milliseconds(5)); + virtual_card_enrollment_manager_->OnDidGetDetailsForEnrollResponse( AutofillClient::PaymentsRpcResult::kSuccess, response); @@ -313,6 +320,12 @@ "Autofill.VirtualCard.GetDetailsForEnrollment.Result." + VirtualCardEnrollmentSourceToMetricSuffix(source), /*sample=*/true, make_image_present ? 1 : 2); + histogram_tester.ExpectBucketCount( + "Autofill.VirtualCard.GetDetailsForEnrollment.Latency." 
+ + VirtualCardEnrollmentSourceToMetricSuffix(source) + + PaymentsRpcResultToMetricsSuffix( + AutofillClient::PaymentsRpcResult::kSuccess), + /*sample=*/5, make_image_present ? 1 : 2); } } }
diff --git a/components/autofill/core/browser/single_field_form_fill_router.cc b/components/autofill/core/browser/single_field_form_fill_router.cc index 1237aaa..1a242f3 100644 --- a/components/autofill/core/browser/single_field_form_fill_router.cc +++ b/components/autofill/core/browser/single_field_form_fill_router.cc
@@ -4,22 +4,49 @@ #include "components/autofill/core/browser/single_field_form_fill_router.h" +#include "components/autofill/core/browser/merchant_promo_code_manager.h" #include "components/autofill/core/browser/suggestions_context.h" -#include "components/autofill/core/common/autofill_payments_features.h" namespace autofill { SingleFieldFormFillRouter::SingleFieldFormFillRouter( - AutocompleteHistoryManager* autocomplete_history_manager) { - autocomplete_history_manager_ = autocomplete_history_manager->GetWeakPtr(); -} + AutocompleteHistoryManager* autocomplete_history_manager, + base::WeakPtr<MerchantPromoCodeManager> merchant_promo_code_manager) + : autocomplete_history_manager_(autocomplete_history_manager->GetWeakPtr()), + merchant_promo_code_manager_(merchant_promo_code_manager) {} SingleFieldFormFillRouter::~SingleFieldFormFillRouter() = default; -void SingleFieldFormFillRouter::OnWillSubmitForm(const FormData& form, - bool is_autocomplete_enabled) { +void SingleFieldFormFillRouter::OnWillSubmitForm( + const FormData& form, + raw_ptr<const FormStructure> form_structure, + bool is_autocomplete_enabled) { + if (form_structure) + DCHECK(form.fields.size() == form_structure->field_count()); + std::vector<FormFieldData> autocomplete_fields; + std::vector<FormFieldData> merchant_promo_code_fields; + autocomplete_fields.reserve(form.fields.size()); + merchant_promo_code_fields.reserve(form.fields.size()); + for (size_t i = 0; i < form.fields.size(); i++) { + // If |form_structure| is present, then the fields in |form_structure| and + // the fields in |form| should be 1:1. |form_structure| not being present + // indicates we may have fields that were not able to be parsed, so we route + // them to autocomplete functionality by default. 
+ if (merchant_promo_code_manager_ && form_structure && + form_structure->field(i)->Type().GetStorableType() == + MERCHANT_PROMO_CODE) { + merchant_promo_code_fields.push_back(form.fields[i]); + } else { + autocomplete_fields.push_back(form.fields[i]); + } + } + + if (merchant_promo_code_manager_) { + merchant_promo_code_manager_->OnWillSubmitFormWithFields( + merchant_promo_code_fields, is_autocomplete_enabled); + } autocomplete_history_manager_->OnWillSubmitFormWithFields( - form.fields, is_autocomplete_enabled); + autocomplete_fields, is_autocomplete_enabled); } void SingleFieldFormFillRouter::OnGetSingleFieldSuggestions( @@ -31,9 +58,17 @@ const std::string& form_control_type, base::WeakPtr<SingleFieldFormFiller::SuggestionsHandler> handler, const SuggestionsContext& context) { - autocomplete_history_manager_->OnGetSingleFieldSuggestions( - query_id, is_autocomplete_enabled, autoselect_first_suggestion, name, - prefix, form_control_type, handler, context); + // Retrieving suggestions for a new field; select the appropriate filler. 
+ if (merchant_promo_code_manager_ && context.focused_field && + context.focused_field->Type().GetStorableType() == MERCHANT_PROMO_CODE) { + merchant_promo_code_manager_->OnGetSingleFieldSuggestions( + query_id, is_autocomplete_enabled, autoselect_first_suggestion, name, + prefix, form_control_type, handler, context); + } else { + autocomplete_history_manager_->OnGetSingleFieldSuggestions( + query_id, is_autocomplete_enabled, autoselect_first_suggestion, name, + prefix, form_control_type, handler, context); + } } void SingleFieldFormFillRouter::OnWillSubmitFormWithFields( @@ -44,21 +79,35 @@ const SingleFieldFormFiller::SuggestionsHandler* handler) { if (autocomplete_history_manager_) autocomplete_history_manager_->CancelPendingQueries(handler); + if (merchant_promo_code_manager_) + merchant_promo_code_manager_->CancelPendingQueries(handler); } void SingleFieldFormFillRouter::OnRemoveCurrentSingleFieldSuggestion( const std::u16string& field_name, const std::u16string& value, int frontend_id) { - autocomplete_history_manager_->OnRemoveCurrentSingleFieldSuggestion( - field_name, value, frontend_id); + if (merchant_promo_code_manager_ && + frontend_id == POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY) { + merchant_promo_code_manager_->OnRemoveCurrentSingleFieldSuggestion( + field_name, value, frontend_id); + } else { + autocomplete_history_manager_->OnRemoveCurrentSingleFieldSuggestion( + field_name, value, frontend_id); + } } void SingleFieldFormFillRouter::OnSingleFieldSuggestionSelected( const std::u16string& value, int frontend_id) { - autocomplete_history_manager_->OnSingleFieldSuggestionSelected(value, - frontend_id); + if (merchant_promo_code_manager_ && + frontend_id == POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY) { + merchant_promo_code_manager_->OnSingleFieldSuggestionSelected(value, + frontend_id); + } else { + autocomplete_history_manager_->OnSingleFieldSuggestionSelected(value, + frontend_id); + } } } // namespace autofill
diff --git a/components/autofill/core/browser/single_field_form_fill_router.h b/components/autofill/core/browser/single_field_form_fill_router.h index b2811d8c..b398672c 100644 --- a/components/autofill/core/browser/single_field_form_fill_router.h +++ b/components/autofill/core/browser/single_field_form_fill_router.h
@@ -7,6 +7,8 @@ #include "base/memory/weak_ptr.h" #include "components/autofill/core/browser/autocomplete_history_manager.h" +#include "components/autofill/core/browser/form_structure.h" +#include "components/autofill/core/browser/merchant_promo_code_manager.h" #include "components/autofill/core/browser/single_field_form_filler.h" #include "components/autofill/core/common/form_data.h" @@ -20,7 +22,8 @@ class SingleFieldFormFillRouter : public SingleFieldFormFiller { public: explicit SingleFieldFormFillRouter( - AutocompleteHistoryManager* autocomplete_history_manager); + AutocompleteHistoryManager* autocomplete_history_manager, + base::WeakPtr<MerchantPromoCodeManager> merchant_promo_code_manager); ~SingleFieldFormFillRouter() override; SingleFieldFormFillRouter(const SingleFieldFormFillRouter&) = delete; SingleFieldFormFillRouter& operator=(const SingleFieldFormFillRouter&) = @@ -28,8 +31,13 @@ // Routes every field in a form to its correct SingleFieldFormFiller, calling // SingleFieldFormFiller::OnWillSubmitFormWithFields() with the vector of - // fields for that specific SingleFieldFormFiller. + // fields for that specific SingleFieldFormFiller. If |form_structure| is not + // nullptr, then the fields in |form| and |form_structure| should be 1:1. It + // is possible for |form_structure| to be nullptr while |form| has data, which + // means there were fields in the form that were not able to be parsed as + // autofill fields. virtual void OnWillSubmitForm(const FormData& form, + raw_ptr<const FormStructure> form_structure, bool is_autocomplete_enabled); // SingleFieldFormFiller overrides: @@ -53,8 +61,12 @@ int frontend_id) override; private: - // Available single field form fillers: + // Handles autocompleting single fields. base::WeakPtr<AutocompleteHistoryManager> autocomplete_history_manager_; + + // Handles autofilling merchant promo code fields (can be null for unsupported + // platforms). 
+ base::WeakPtr<MerchantPromoCodeManager> merchant_promo_code_manager_; }; } // namespace autofill
diff --git a/components/autofill/core/browser/single_field_form_fill_router_unittest.cc b/components/autofill/core/browser/single_field_form_fill_router_unittest.cc index 3ae6fc5..bc028a6ae 100644 --- a/components/autofill/core/browser/single_field_form_fill_router_unittest.cc +++ b/components/autofill/core/browser/single_field_form_fill_router_unittest.cc
@@ -3,16 +3,31 @@ // found in the LICENSE file. #include "components/autofill/core/browser/single_field_form_fill_router.h" +#include "base/test/scoped_feature_list.h" #include "base/test/task_environment.h" #include "components/autofill/core/browser/autofill_test_utils.h" #include "components/autofill/core/browser/mock_autocomplete_history_manager.h" +#include "components/autofill/core/browser/mock_merchant_promo_code_manager.h" #include "components/autofill/core/browser/suggestions_context.h" +#include "components/autofill/core/browser/test_form_structure.h" +#include "components/autofill/core/browser/test_personal_data_manager.h" #include "components/autofill/core/browser/webdata/mock_autofill_webdata_service.h" +#include "components/autofill/core/common/autofill_features.h" +#include "components/autofill/core/common/autofill_payments_features.h" #include "components/autofill/core/common/autofill_prefs.h" #include "components/version_info/version_info.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using testing::_; +using testing::DoAll; +using testing::SaveArg; namespace autofill { +using FieldPrediction = + AutofillQueryResponse::FormSuggestion::FieldSuggestion::FieldPrediction; + namespace { class MockSuggestionsHandler : public SingleFieldFormFiller::SuggestionsHandler { @@ -46,19 +61,27 @@ // Mock such that we don't trigger the cleanup. 
prefs_->SetInteger(prefs::kAutocompleteLastVersionRetentionPolicy, CHROME_VERSION_MAJOR); + personal_data_manager_ = std::make_unique<TestPersonalDataManager>(); web_data_service_ = base::MakeRefCounted<MockAutofillWebDataService>(); autocomplete_history_manager_ = std::make_unique<MockAutocompleteHistoryManager>(); autocomplete_history_manager_->Init(web_data_service_, prefs_.get(), false); + merchant_promo_code_manager_ = + std::make_unique<MockMerchantPromoCodeManager>(); + merchant_promo_code_manager_->Init(personal_data_manager_.get(), + /*is_off_the_record=*/false); single_field_form_fill_router_ = std::make_unique<SingleFieldFormFillRouter>( - autocomplete_history_manager_.get()); + autocomplete_history_manager_.get(), + merchant_promo_code_manager_.get()->GetWeakPtr()); } base::test::SingleThreadTaskEnvironment task_environment_; std::unique_ptr<SingleFieldFormFillRouter> single_field_form_fill_router_; + std::unique_ptr<TestPersonalDataManager> personal_data_manager_; scoped_refptr<MockAutofillWebDataService> web_data_service_; std::unique_ptr<MockAutocompleteHistoryManager> autocomplete_history_manager_; + std::unique_ptr<MockMerchantPromoCodeManager> merchant_promo_code_manager_; std::unique_ptr<PrefService> prefs_; }; @@ -78,24 +101,83 @@ SuggestionsContext()); } -// Ensure that the router routes to AutocompleteHistoryManager for this +// Ensure that the router routes to all SingleFieldFormFillers for this // OnWillSubmitForm call. 
TEST_F(SingleFieldFormFillRouterTest, - RouteToAutocompleteHistoryManager_OnWillSubmitForm) { - EXPECT_CALL(*autocomplete_history_manager_, OnWillSubmitFormWithFields); + RouteToAllSingleFieldFormFillers_OnWillSubmitForm) { + FormData form_data; + std::vector<FormFieldData> fields; + size_t number_of_fields_for_testing = 3; + for (size_t i = 0; i < number_of_fields_for_testing; i++) { + fields.emplace_back(); + } + +#if !BUILDFLAG(IS_IOS) + for (size_t i = 0; i < number_of_fields_for_testing; i++) { + fields.emplace_back(); + } +#endif // !BUILDFLAG(IS_IOS) + + form_data.fields = fields; + TestFormStructure form_structure{form_data}; + + // Set the first |number_of_fields_for_testing| fields to be autocomplete + // fields. + for (size_t i = 0; i < number_of_fields_for_testing; i++) { + form_structure.set_server_field_type_for_testing(i, UNKNOWN_TYPE); + } + +#if !BUILDFLAG(IS_IOS) + // Set the next |number_of_fields_for_testing| fields to be merchant promo + // code fields. + for (size_t i = number_of_fields_for_testing; + i < number_of_fields_for_testing * 2; i++) { + form_structure.set_server_field_type_for_testing(i, MERCHANT_PROMO_CODE); + } +#endif // !BUILDFLAG(IS_IOS) + + std::vector<FormFieldData> submitted_autocomplete_fields; + bool autocomplete_fields_is_autocomplete_enabled = false; + EXPECT_CALL(*autocomplete_history_manager_, OnWillSubmitFormWithFields(_, _)) + .WillOnce( + (DoAll(SaveArg<0>(&submitted_autocomplete_fields), + SaveArg<1>(&autocomplete_fields_is_autocomplete_enabled)))); + +#if !BUILDFLAG(IS_IOS) + std::vector<FormFieldData> submitted_merchant_promo_code_fields; + bool merchant_promo_code_fields_is_autocomplete_enabled = false; + EXPECT_CALL(*merchant_promo_code_manager_, OnWillSubmitFormWithFields(_, _)) + .WillOnce((DoAll( + SaveArg<0>(&submitted_merchant_promo_code_fields), + SaveArg<1>(&merchant_promo_code_fields_is_autocomplete_enabled)))); +#endif // !BUILDFLAG(IS_IOS) single_field_form_fill_router_->OnWillSubmitForm( - 
FormData(), /*is_autocomplete_enabled=*/true); + form_data, &form_structure, /*is_autocomplete_enabled=*/true); + + EXPECT_TRUE(submitted_autocomplete_fields.size() == + number_of_fields_for_testing); + EXPECT_TRUE(autocomplete_fields_is_autocomplete_enabled); + +#if !BUILDFLAG(IS_IOS) + EXPECT_TRUE(submitted_merchant_promo_code_fields.size() == + number_of_fields_for_testing); + EXPECT_TRUE(merchant_promo_code_fields_is_autocomplete_enabled); +#endif // !BUILDFLAG(IS_IOS) } -// Ensure that the router routes to AutocompleteHistoryManager for this +// Ensure that the router routes to SingleFieldFormFillers for this // CancelPendingQueries call. TEST_F(SingleFieldFormFillRouterTest, - RouteToAutocompleteHistoryManager_CancelPendingQueries) { + RouteToAllSingleFieldFormFillers_CancelPendingQueries) { auto suggestions_handler = std::make_unique<MockSuggestionsHandler>(); EXPECT_CALL(*autocomplete_history_manager_, CancelPendingQueries); +#if !BUILDFLAG(IS_IOS) + EXPECT_CALL(*merchant_promo_code_manager_, CancelPendingQueries); +#endif // !BUILDFLAG(IS_IOS) + single_field_form_fill_router_->CancelPendingQueries( suggestions_handler.get()); } @@ -122,4 +204,57 @@ /*value=*/u"Value", POPUP_ITEM_ID_AUTOCOMPLETE_ENTRY); } +#if !BUILDFLAG(IS_IOS) +// Ensure that the router routes to MerchantPromoCodeManager for this +// OnGetSingleFieldSuggestions call. 
+TEST_F(SingleFieldFormFillRouterTest, + RouteToMerchantPromoCodeManager_OnGetSingleFieldSuggestions) { + base::test::ScopedFeatureList feature_list; + feature_list.InitWithFeatures({features::kAutofillFillMerchantPromoCodeFields, + features::kAutofillServerTypeTakesPrecedence}, + {}); + auto suggestions_handler = std::make_unique<MockSuggestionsHandler>(); + + EXPECT_CALL(*merchant_promo_code_manager_, OnGetSingleFieldSuggestions); + std::vector<FieldPrediction> merchant_promo_code_field_predictions; + FieldPrediction merchant_promo_code_field_prediction; + merchant_promo_code_field_prediction.set_type(MERCHANT_PROMO_CODE); + merchant_promo_code_field_predictions.push_back( + merchant_promo_code_field_prediction); + SuggestionsContext context; + AutofillField autofill_field; + autofill_field.set_server_predictions( + std::move(merchant_promo_code_field_predictions)); + context.focused_field = &autofill_field; + single_field_form_fill_router_->OnGetSingleFieldSuggestions( + /*query_id=*/2, /*is_autocomplete_enabled=*/true, + /*autoselect_first_suggestion=*/false, /*name=*/u"Some Field Name", + /*prefix=*/u"SomePrefix", + /*form_control_type=*/"SomeType", suggestions_handler->GetWeakPtr(), + context); +} + +// Ensure that the router routes to MerchantPromoCodeManager for this +// OnRemoveCurrentSingleFieldSuggestion call. +TEST_F(SingleFieldFormFillRouterTest, + RouteToMerchantPromoCodeManager_OnRemoveCurrentSingleFieldSuggestion) { + EXPECT_CALL(*merchant_promo_code_manager_, + OnRemoveCurrentSingleFieldSuggestion); + + single_field_form_fill_router_->OnRemoveCurrentSingleFieldSuggestion( + /*field_name=*/u"Field Name", /*value=*/u"Value", + POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY); +} + +// Ensure that the router routes to MerchantPromoCodeManager for this +// OnSingleFieldSuggestionSelected call. 
+TEST_F(SingleFieldFormFillRouterTest, + RouteToMerchantPromoCodeManager_OnSingleFieldSuggestionSelected) { + EXPECT_CALL(*merchant_promo_code_manager_, OnSingleFieldSuggestionSelected); + + single_field_form_fill_router_->OnSingleFieldSuggestionSelected( + /*value=*/u"Value", POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY); +} +#endif // !BUILDFLAG(IS_IOS) + } // namespace autofill
diff --git a/components/autofill/core/browser/single_field_form_filler.h b/components/autofill/core/browser/single_field_form_filler.h index 4131c26..41414ab 100644 --- a/components/autofill/core/browser/single_field_form_filler.h +++ b/components/autofill/core/browser/single_field_form_filler.h
@@ -64,7 +64,7 @@ // saves the given |fields| that are eligible to be saved as new or updated // Autocomplete entries, which can then be served in the future as // suggestions. This update is dependent on whether we are running in - // incognito and if Autocomplete is enabled or not. + // incognito and if Autocomplete is enabled or not. |fields| can be empty. virtual void OnWillSubmitFormWithFields( const std::vector<FormFieldData>& fields, bool is_autocomplete_enabled) = 0;
diff --git a/components/autofill/core/browser/ui/popup_item_ids.h b/components/autofill/core/browser/ui/popup_item_ids.h index d110d03..45dd91dd 100644 --- a/components/autofill/core/browser/ui/popup_item_ids.h +++ b/components/autofill/core/browser/ui/popup_item_ids.h
@@ -48,7 +48,8 @@ POPUP_ITEM_ID_USERNAME_ENTRY, POPUP_ITEM_ID_ACCOUNT_STORAGE_PASSWORD_ENTRY, POPUP_ITEM_ID_ACCOUNT_STORAGE_USERNAME_ENTRY, - POPUP_ITEM_ID_VIRTUAL_CREDIT_CARD_ENTRY}; + POPUP_ITEM_ID_VIRTUAL_CREDIT_CARD_ENTRY, + POPUP_ITEM_ID_MERCHANT_PROMO_CODE_ENTRY}; } // namespace autofill
diff --git a/components/commerce_strings.grdp b/components/commerce_strings.grdp index 0c9847f..257c88d7 100644 --- a/components/commerce_strings.grdp +++ b/components/commerce_strings.grdp
@@ -23,6 +23,9 @@ <message name="IDS_DISCOUNT_CONTEXTUAL_CONSENT_ACCEPTED_CONFIRMATION_DONE" desc="The text shown on the consent confirmation bubble bubble. User can click this button to close the bubble."> Done </message> + <message name="IDS_NATIVE_NTP_CART_DISCOUNT_CONSENT_ACCEPT_BUTTON" translateable="false" desc="Text shown on the accept button of the consent dialog. By clicking this button, the user has chosen to accept the consent."> + Yes, I'm in + </message> <message name="IDS_NATIVE_NTP_CART_DISCOUNT_CONSENT_TITLE" translateable="false" desc="The title shown on the native consent dialog for getting discount. Note: This is intentionally left blank. No translation is required."> Get discounts on your carts & when you shop online </message>
diff --git a/components/exo/pointer.cc b/components/exo/pointer.cc index e0bcb66..008a4ef7 100644 --- a/components/exo/pointer.cc +++ b/components/exo/pointer.cc
@@ -478,7 +478,7 @@ // ui::EventHandler overrides: void Pointer::OnMouseEvent(ui::MouseEvent* event) { - if (seat_->was_shutdown()) + if (seat_->was_shutdown() || event->handled()) return; // Nothing to report to a client nor have to update the pointer when capture
diff --git a/components/exo/pointer_unittest.cc b/components/exo/pointer_unittest.cc index ef0c283..9cd3d422 100644 --- a/components/exo/pointer_unittest.cc +++ b/components/exo/pointer_unittest.cc
@@ -1176,6 +1176,45 @@ pointer.reset(); } +TEST_F(PointerTest, IgnoresHandledEvents) { + // A very dumb handler that simply marks all events as handled. This is needed + // so that we can mark a mouse event as handled as it gets processed by the + // event processor. + class SetHandledHandler : public ui::EventHandler { + void OnMouseEvent(ui::MouseEvent* event) override { event->SetHandled(); } + }; + SetHandledHandler handler; + ash::Shell::Get()->AddPreTargetHandler(&handler); + + Seat seat(std::make_unique<TestDataExchangeDelegate>()); + testing::NiceMock<MockPointerDelegate> pointer_delegate; + std::unique_ptr<Pointer> pointer(new Pointer(&pointer_delegate, &seat)); + + // Make origin into a real window so the mouse can click it + std::unique_ptr<ShellSurface> shell_surface = + test::ShellSurfaceBuilder({10, 10}).BuildShellSurface(); + + EXPECT_CALL(pointer_delegate, CanAcceptPointerEventsForSurface(testing::_)) + .WillRepeatedly(testing::Return(true)); + ui::test::EventGenerator generator(ash::Shell::GetPrimaryRootWindow()); + + // The SetHandledHandler should have marked the event as processed. Therefore + // the event should simply be ignored. + EXPECT_CALL(pointer_delegate, + OnPointerButton(testing::_, testing::_, testing::_)) + .Times(0); + + // This event should be ignored because it has already been handled. + auto window_point = shell_surface->surface_for_testing() + ->window() + ->GetBoundsInScreen() + .CenterPoint(); + generator.MoveMouseTo(window_point); + generator.ClickLeftButton(); + + ash::Shell::Get()->RemovePreTargetHandler(&handler); +} + namespace { class PointerDragDropObserver : public WMHelper::DragDropObserver {
diff --git a/components/exo/shell_surface.cc b/components/exo/shell_surface.cc index b17cf4d..62035e6 100644 --- a/components/exo/shell_surface.cc +++ b/components/exo/shell_surface.cc
@@ -602,9 +602,12 @@ GetClientBoundsInScreen(widget_), window_state->GetStateType(), IsResizing(), widget_->IsActive(), origin_offset); } else { - serial = configure_callback_.Run(gfx::Rect(), - chromeos::WindowStateType::kNormal, - false, false, origin_offset); + gfx::Rect bounds; + if (initial_bounds_) + bounds.set_origin(initial_bounds_->origin()); + serial = + configure_callback_.Run(bounds, chromeos::WindowStateType::kNormal, + false, false, origin_offset); } }
diff --git a/components/exo/text_input.cc b/components/exo/text_input.cc index 2117c5e..db5c088 100644 --- a/components/exo/text_input.cc +++ b/components/exo/text_input.cc
@@ -360,7 +360,8 @@ return false; } -absl::optional<ui::GrammarFragment> TextInput::GetGrammarFragmentAtCursor() { +absl::optional<ui::GrammarFragment> TextInput::GetGrammarFragmentAtCursor() + const { // TODO(https://crbug.com/1201454): Implement this method. NOTIMPLEMENTED_LOG_ONCE(); return absl::nullopt;
diff --git a/components/exo/text_input.h b/components/exo/text_input.h index 3b97f00..f31d7b8 100644 --- a/components/exo/text_input.h +++ b/components/exo/text_input.h
@@ -167,7 +167,8 @@ gfx::Range GetAutocorrectRange() const override; gfx::Rect GetAutocorrectCharacterBounds() const override; bool SetAutocorrectRange(const gfx::Range& range) override; - absl::optional<ui::GrammarFragment> GetGrammarFragmentAtCursor() override; + absl::optional<ui::GrammarFragment> GetGrammarFragmentAtCursor() + const override; bool ClearGrammarFragments(const gfx::Range& range) override; bool AddGrammarFragments( const std::vector<ui::GrammarFragment>& fragments) override;
diff --git a/components/exo/touch.cc b/components/exo/touch.cc index 2991da66..298bcad 100644 --- a/components/exo/touch.cc +++ b/components/exo/touch.cc
@@ -62,7 +62,7 @@ // ui::EventHandler overrides: void Touch::OnTouchEvent(ui::TouchEvent* event) { - if (seat_->was_shutdown()) + if (seat_->was_shutdown() || event->handled()) return; bool send_details = false;
diff --git a/components/exo/touch_unittest.cc b/components/exo/touch_unittest.cc index a5692111..b9cd717 100644 --- a/components/exo/touch_unittest.cc +++ b/components/exo/touch_unittest.cc
@@ -628,5 +628,38 @@ touch.reset(); } +TEST_F(TouchTest, IgnoresHandledEvents) { + // A very dumb handler that simply marks all events as handled. This is needed + // so that we can mark a touch event as handled as it gets processed by the + // event processor. + class SetHandledHandler : public ui::EventHandler { + void OnTouchEvent(ui::TouchEvent* event) override { event->SetHandled(); } + }; + SetHandledHandler handler; + ash::Shell::Get()->AddPreTargetHandler(&handler); + + Seat seat(std::make_unique<TestDataExchangeDelegate>()); + + testing::NiceMock<MockTouchDelegate> touch_delegate; + std::unique_ptr<Touch> touch(new Touch(&touch_delegate, &seat)); + + // Make origin into a real window so the touch can tap it + std::unique_ptr<ShellSurface> shell_surface = + test::ShellSurfaceBuilder({10, 10}).BuildShellSurface(); + + ui::test::EventGenerator generator(ash::Shell::GetPrimaryRootWindow()); + + // The SetHandledHandler should have marked the event as processed. Therefore + // the event should simply be ignored. + EXPECT_CALL(touch_delegate, OnTouchFrame()).Times(0); + + generator.GestureTapAt(shell_surface->surface_for_testing() + ->window() + ->GetBoundsInScreen() + .CenterPoint()); + + ash::Shell::Get()->RemovePreTargetHandler(&handler); +} + } // namespace } // namespace exo
diff --git a/components/fuchsia_legacymetrics/BUILD.gn b/components/fuchsia_legacymetrics/BUILD.gn new file mode 100644 index 0000000..7ce18eb --- /dev/null +++ b/components/fuchsia_legacymetrics/BUILD.gn
@@ -0,0 +1,43 @@ +# Copyright 2022 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +assert(is_fuchsia) + +visibility = [] + +source_set("fuchsia_legacymetrics") { + visibility += [ + ":unit_tests", + "//chromecast/internal/*", + "//fuchsia/engine/*", + ] + sources = [ + "legacymetrics_client.cc", + "legacymetrics_histogram_flattener.cc", + "legacymetrics_histogram_flattener.h", + "legacymetrics_user_event_recorder.cc", + "legacymetrics_user_event_recorder.h", + ] + public = [ "legacymetrics_client.h" ] + deps = [ "//base" ] + public_deps = [ "//third_party/fuchsia-sdk/sdk/fidl/fuchsia.legacymetrics" ] + friend = [ ":unit_tests" ] # For access to private headers. +} + +source_set("unit_tests") { + testonly = true + visibility += [ "//components:components_unittests__exec" ] + sources = [ + "legacymetrics_client_unittest.cc", + "legacymetrics_histogram_flattener_unittest.cc", + "legacymetrics_user_event_recorder_unittest.cc", + ] + public_deps = [ + ":fuchsia_legacymetrics", + "//base", + "//base/test:test_support", + "//testing/gmock", + "//testing/gtest", + ] +}
diff --git a/components/fuchsia_legacymetrics/DIR_METADATA b/components/fuchsia_legacymetrics/DIR_METADATA new file mode 100644 index 0000000..210aa6a --- /dev/null +++ b/components/fuchsia_legacymetrics/DIR_METADATA
@@ -0,0 +1 @@ +mixins: "//build/fuchsia/COMMON_METADATA"
diff --git a/components/fuchsia_legacymetrics/OWNERS b/components/fuchsia_legacymetrics/OWNERS new file mode 100644 index 0000000..fd58a1c25 --- /dev/null +++ b/components/fuchsia_legacymetrics/OWNERS
@@ -0,0 +1,2 @@ +kmarshall@chromium.org +file://build/fuchsia/OWNERS
diff --git a/fuchsia/base/legacymetrics_client.cc b/components/fuchsia_legacymetrics/legacymetrics_client.cc similarity index 97% rename from fuchsia/base/legacymetrics_client.cc rename to components/fuchsia_legacymetrics/legacymetrics_client.cc index 6ac802d..f72dc9c 100644 --- a/fuchsia/base/legacymetrics_client.cc +++ b/components/fuchsia_legacymetrics/legacymetrics_client.cc
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "fuchsia/base/legacymetrics_client.h" +#include "components/fuchsia_legacymetrics/legacymetrics_client.h" #include <lib/fit/function.h> #include <lib/sys/cpp/component_context.h> @@ -19,9 +19,9 @@ #include "base/logging.h" #include "base/threading/thread_task_runner_handle.h" #include "base/time/time.h" -#include "fuchsia/base/legacymetrics_histogram_flattener.h" +#include "components/fuchsia_legacymetrics/legacymetrics_histogram_flattener.h" -namespace cr_fuchsia { +namespace fuchsia_legacymetrics { constexpr size_t LegacyMetricsClient::kMaxBatchSize; @@ -296,4 +296,4 @@ CompleteFlush(); } -} // namespace cr_fuchsia +} // namespace fuchsia_legacymetrics
diff --git a/fuchsia/base/legacymetrics_client.h b/components/fuchsia_legacymetrics/legacymetrics_client.h similarity index 92% rename from fuchsia/base/legacymetrics_client.h rename to components/fuchsia_legacymetrics/legacymetrics_client.h index b222b94..0f193ada 100644 --- a/fuchsia/base/legacymetrics_client.h +++ b/components/fuchsia_legacymetrics/legacymetrics_client.h
@@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef FUCHSIA_BASE_LEGACYMETRICS_CLIENT_H_ -#define FUCHSIA_BASE_LEGACYMETRICS_CLIENT_H_ +#ifndef COMPONENTS_FUCHSIA_LEGACYMETRICS_LEGACYMETRICS_CLIENT_H_ +#define COMPONENTS_FUCHSIA_LEGACYMETRICS_LEGACYMETRICS_CLIENT_H_ #include <fuchsia/legacymetrics/cpp/fidl.h> @@ -15,9 +15,9 @@ #include "base/sequence_checker.h" #include "base/time/time.h" #include "base/timer/timer.h" -#include "fuchsia/base/legacymetrics_user_event_recorder.h" +#include "components/fuchsia_legacymetrics/legacymetrics_user_event_recorder.h" -namespace cr_fuchsia { +namespace fuchsia_legacymetrics { // Used to report events & histogram data to the // fuchsia.legacymetrics.MetricsRecorder service. @@ -119,6 +119,6 @@ base::WeakPtrFactory<LegacyMetricsClient> weak_factory_{this}; }; -} // namespace cr_fuchsia +} // namespace fuchsia_legacymetrics -#endif // FUCHSIA_BASE_LEGACYMETRICS_CLIENT_H_ +#endif // COMPONENTS_FUCHSIA_LEGACYMETRICS_LEGACYMETRICS_CLIENT_H_
diff --git a/fuchsia/base/legacymetrics_client_unittest.cc b/components/fuchsia_legacymetrics/legacymetrics_client_unittest.cc similarity index 98% rename from fuchsia/base/legacymetrics_client_unittest.cc rename to components/fuchsia_legacymetrics/legacymetrics_client_unittest.cc index 3fdc4857..5ac72f4 100644 --- a/fuchsia/base/legacymetrics_client_unittest.cc +++ b/components/fuchsia_legacymetrics/legacymetrics_client_unittest.cc
@@ -18,12 +18,12 @@ #include "base/test/test_future.h" #include "base/threading/thread_task_runner_handle.h" #include "base/time/time.h" -#include "fuchsia/base/legacymetrics_client.h" -#include "fuchsia/base/legacymetrics_histogram_flattener.h" +#include "components/fuchsia_legacymetrics/legacymetrics_client.h" +#include "components/fuchsia_legacymetrics/legacymetrics_histogram_flattener.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" -namespace cr_fuchsia { +namespace fuchsia_legacymetrics { namespace { using ::testing::Property; @@ -704,4 +704,4 @@ } } // namespace -} // namespace cr_fuchsia +} // namespace fuchsia_legacymetrics
diff --git a/fuchsia/base/legacymetrics_histogram_flattener.cc b/components/fuchsia_legacymetrics/legacymetrics_histogram_flattener.cc similarity index 94% rename from fuchsia/base/legacymetrics_histogram_flattener.cc rename to components/fuchsia_legacymetrics/legacymetrics_histogram_flattener.cc index b6d7281..65a5a02 100644 --- a/fuchsia/base/legacymetrics_histogram_flattener.cc +++ b/components/fuchsia_legacymetrics/legacymetrics_histogram_flattener.cc
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "fuchsia/base/legacymetrics_histogram_flattener.h" +#include "components/fuchsia_legacymetrics/legacymetrics_histogram_flattener.h" #include <memory> #include <utility> @@ -11,7 +11,7 @@ #include "base/logging.h" #include "base/metrics/statistics_recorder.h" -namespace cr_fuchsia { +namespace fuchsia_legacymetrics { namespace { // Serializes changes to histogram metrics as FIDL structs. @@ -89,4 +89,4 @@ return LegacyMetricsHistogramFlattener().GetDeltas(); } -} // namespace cr_fuchsia +} // namespace fuchsia_legacymetrics
diff --git a/components/fuchsia_legacymetrics/legacymetrics_histogram_flattener.h b/components/fuchsia_legacymetrics/legacymetrics_histogram_flattener.h new file mode 100644 index 0000000..773087f0 --- /dev/null +++ b/components/fuchsia_legacymetrics/legacymetrics_histogram_flattener.h
@@ -0,0 +1,20 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_FUCHSIA_LEGACYMETRICS_LEGACYMETRICS_HISTOGRAM_FLATTENER_H_ +#define COMPONENTS_FUCHSIA_LEGACYMETRICS_LEGACYMETRICS_HISTOGRAM_FLATTENER_H_ + +#include <fuchsia/legacymetrics/cpp/fidl.h> +#include <vector> + +#include "base/metrics/histogram_flattener.h" +#include "base/metrics/histogram_snapshot_manager.h" + +namespace fuchsia_legacymetrics { + +std::vector<fuchsia::legacymetrics::Histogram> GetLegacyMetricsDeltas(); + +} // namespace fuchsia_legacymetrics + +#endif // COMPONENTS_FUCHSIA_LEGACYMETRICS_LEGACYMETRICS_HISTOGRAM_FLATTENER_H_
diff --git a/fuchsia/base/legacymetrics_histogram_flattener_unittest.cc b/components/fuchsia_legacymetrics/legacymetrics_histogram_flattener_unittest.cc similarity index 96% rename from fuchsia/base/legacymetrics_histogram_flattener_unittest.cc rename to components/fuchsia_legacymetrics/legacymetrics_histogram_flattener_unittest.cc index ad4bfc5..0a92d55 100644 --- a/fuchsia/base/legacymetrics_histogram_flattener_unittest.cc +++ b/components/fuchsia_legacymetrics/legacymetrics_histogram_flattener_unittest.cc
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "fuchsia/base/legacymetrics_histogram_flattener.h" +#include "components/fuchsia_legacymetrics/legacymetrics_histogram_flattener.h" #include "base/metrics/histogram_macros.h" #include "testing/gtest/include/gtest/gtest.h" @@ -10,7 +10,7 @@ using fuchsia::legacymetrics::Histogram; using fuchsia::legacymetrics::HistogramBucket; -namespace cr_fuchsia { +namespace fuchsia_legacymetrics { namespace { constexpr char kHistogramCount1M[] = "Foo.Bar"; @@ -164,4 +164,4 @@ } } // namespace -} // namespace cr_fuchsia +} // namespace fuchsia_legacymetrics
diff --git a/fuchsia/base/legacymetrics_user_event_recorder.cc b/components/fuchsia_legacymetrics/legacymetrics_user_event_recorder.cc similarity index 89% rename from fuchsia/base/legacymetrics_user_event_recorder.cc rename to components/fuchsia_legacymetrics/legacymetrics_user_event_recorder.cc index f762c51..d33af51 100644 --- a/fuchsia/base/legacymetrics_user_event_recorder.cc +++ b/components/fuchsia_legacymetrics/legacymetrics_user_event_recorder.cc
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "fuchsia/base/legacymetrics_user_event_recorder.h" +#include "components/fuchsia_legacymetrics/legacymetrics_user_event_recorder.h" #include <utility> @@ -10,7 +10,7 @@ #include "base/metrics/user_metrics.h" #include "base/time/time.h" -namespace cr_fuchsia { +namespace fuchsia_legacymetrics { constexpr size_t LegacyMetricsUserActionRecorder::kMaxEventCount; @@ -45,4 +45,4 @@ events_.push_back(std::move(fidl_event)); } -} // namespace cr_fuchsia +} // namespace fuchsia_legacymetrics
diff --git a/fuchsia/base/legacymetrics_user_event_recorder.h b/components/fuchsia_legacymetrics/legacymetrics_user_event_recorder.h similarity index 78% rename from fuchsia/base/legacymetrics_user_event_recorder.h rename to components/fuchsia_legacymetrics/legacymetrics_user_event_recorder.h index fbfff2e..0ab7d40 100644 --- a/fuchsia/base/legacymetrics_user_event_recorder.h +++ b/components/fuchsia_legacymetrics/legacymetrics_user_event_recorder.h
@@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef FUCHSIA_BASE_LEGACYMETRICS_USER_EVENT_RECORDER_H_ -#define FUCHSIA_BASE_LEGACYMETRICS_USER_EVENT_RECORDER_H_ +#ifndef COMPONENTS_FUCHSIA_LEGACYMETRICS_LEGACYMETRICS_USER_EVENT_RECORDER_H_ +#define COMPONENTS_FUCHSIA_LEGACYMETRICS_LEGACYMETRICS_USER_EVENT_RECORDER_H_ #include <fuchsia/legacymetrics/cpp/fidl.h> #include <string> @@ -11,7 +11,7 @@ #include "base/metrics/user_metrics.h" -namespace cr_fuchsia { +namespace fuchsia_legacymetrics { // Captures and stores user action events, and converts them to // fuchsia.legacymetrics equivalent. @@ -39,6 +39,6 @@ const base::ActionCallback on_event_callback_; }; -} // namespace cr_fuchsia +} // namespace fuchsia_legacymetrics -#endif // FUCHSIA_BASE_LEGACYMETRICS_USER_EVENT_RECORDER_H_ +#endif // COMPONENTS_FUCHSIA_LEGACYMETRICS_LEGACYMETRICS_USER_EVENT_RECORDER_H_
diff --git a/fuchsia/base/legacymetrics_user_event_recorder_unittest.cc b/components/fuchsia_legacymetrics/legacymetrics_user_event_recorder_unittest.cc similarity index 94% rename from fuchsia/base/legacymetrics_user_event_recorder_unittest.cc rename to components/fuchsia_legacymetrics/legacymetrics_user_event_recorder_unittest.cc index 61af5d19..8fa8174 100644 --- a/fuchsia/base/legacymetrics_user_event_recorder_unittest.cc +++ b/components/fuchsia_legacymetrics/legacymetrics_user_event_recorder_unittest.cc
@@ -2,13 +2,13 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "fuchsia/base/legacymetrics_user_event_recorder.h" +#include "components/fuchsia_legacymetrics/legacymetrics_user_event_recorder.h" #include "base/test/task_environment.h" #include "base/threading/thread_task_runner_handle.h" #include "testing/gtest/include/gtest/gtest.h" -namespace cr_fuchsia { +namespace fuchsia_legacymetrics { namespace { class LegacyMetricsUserActionRecorderTest : public testing::Test { @@ -87,4 +87,4 @@ } } // namespace -} // namespace cr_fuchsia +} // namespace fuchsia_legacymetrics
diff --git a/components/page_load_metrics/browser/observers/use_counter/ukm_features.cc b/components/page_load_metrics/browser/observers/use_counter/ukm_features.cc index beec227..42388f0 100644 --- a/components/page_load_metrics/browser/observers/use_counter/ukm_features.cc +++ b/components/page_load_metrics/browser/observers/use_counter/ukm_features.cc
@@ -227,6 +227,11 @@ WebFeature::kCookieHasNotBeenRefreshedIn301To350Days, WebFeature::kCookieHasNotBeenRefreshedIn351To400Days, WebFeature::kPartitionedCookies, + WebFeature::kScriptSchedulingType_Defer, + WebFeature::kScriptSchedulingType_ParserBlocking, + WebFeature::kScriptSchedulingType_ParserBlockingInline, + WebFeature::kScriptSchedulingType_InOrder, + WebFeature::kScriptSchedulingType_Async, })); return *opt_in_features; }
diff --git a/components/performance_manager/BUILD.gn b/components/performance_manager/BUILD.gn index c2779e5..2d58f6cb 100644 --- a/components/performance_manager/BUILD.gn +++ b/components/performance_manager/BUILD.gn
@@ -22,7 +22,6 @@ "decorators/process_hosted_content_types_aggregator.cc", "decorators/process_hosted_content_types_aggregator.h", "decorators/process_metrics_decorator.cc", - "decorators/tab_properties_decorator.cc", "embedder/binders.h", "embedder/graph_features.h", "embedder/performance_manager_lifetime.h", @@ -116,7 +115,6 @@ "public/decorators/page_live_state_decorator.h", "public/decorators/page_load_tracker_decorator_helper.h", "public/decorators/process_metrics_decorator.h", - "public/decorators/tab_properties_decorator.h", "public/execution_context/execution_context.h", "public/execution_context/execution_context_attached_data.h", "public/execution_context/execution_context_registry.h", @@ -260,7 +258,6 @@ "decorators/page_live_state_decorator_unittest.cc", "decorators/page_load_tracker_decorator_unittest.cc", "decorators/process_hosted_content_types_aggregator_unittest.cc", - "decorators/tab_properties_decorator_unittest.cc", "execution_context/execution_context_attached_data_unittest.cc", "execution_context/execution_context_registry_impl_unittest.cc", "execution_context_priority/ad_frame_voter_unittest.cc",
diff --git a/components/performance_manager/decorators/tab_properties_decorator.cc b/components/performance_manager/decorators/tab_properties_decorator.cc deleted file mode 100644 index d629cb88..0000000 --- a/components/performance_manager/decorators/tab_properties_decorator.cc +++ /dev/null
@@ -1,95 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "components/performance_manager/public/decorators/tab_properties_decorator.h" - -#include "components/performance_manager/decorators/decorators_utils.h" -#include "components/performance_manager/graph/node_attached_data_impl.h" -#include "components/performance_manager/graph/page_node_impl.h" -#include "components/performance_manager/public/graph/node_data_describer_registry.h" -#include "components/performance_manager/public/performance_manager.h" -#include "content/public/browser/browser_thread.h" - -namespace performance_manager { - -namespace { - -class TabPropertiesDataImpl - : public TabPropertiesDecorator::Data, - public NodeAttachedDataImpl<TabPropertiesDataImpl> { - public: - struct Traits : public NodeAttachedDataInMap<PageNodeImpl> {}; - ~TabPropertiesDataImpl() override = default; - TabPropertiesDataImpl(const TabPropertiesDataImpl& other) = delete; - TabPropertiesDataImpl& operator=(const TabPropertiesDataImpl&) = delete; - - // TabPropertiesDecorator::Data implementation. - bool IsInTabStrip() const override { return is_tab_; } - - void set_is_tab(bool is_tab) { is_tab_ = is_tab; } - - private: - // Make the impl our friend so it can access the constructor and any - // storage providers. 
- friend class ::performance_manager::NodeAttachedDataImpl< - TabPropertiesDataImpl>; - - explicit TabPropertiesDataImpl(const PageNodeImpl* page_node) {} - - bool is_tab_ = false; -}; - -const char kDescriberName[] = "TabPropertiesDecorator"; - -} // namespace - -void TabPropertiesDecorator::SetIsTab(content::WebContents* contents, - bool is_tab) { - SetPropertyForWebContentsPageNode(contents, - &TabPropertiesDataImpl::set_is_tab, is_tab); -} - -void TabPropertiesDecorator::SetIsTabForTesting(PageNode* page_node, - bool is_tab) { - auto* data = - TabPropertiesDataImpl::GetOrCreate(PageNodeImpl::FromNode(page_node)); - DCHECK(data); - data->set_is_tab(is_tab); -} - -void TabPropertiesDecorator::OnPassedToGraph(Graph* graph) { - graph->GetNodeDataDescriberRegistry()->RegisterDescriber(this, - kDescriberName); -} - -void TabPropertiesDecorator::OnTakenFromGraph(Graph* graph) { - graph->GetNodeDataDescriberRegistry()->UnregisterDescriber(this); -} - -base::Value TabPropertiesDecorator::DescribePageNodeData( - const PageNode* node) const { - auto* data = TabPropertiesDecorator::Data::FromPageNode(node); - if (!data) - return base::Value(); - - base::Value ret(base::Value::Type::DICTIONARY); - ret.SetBoolKey("IsInTabStrip", data->IsInTabStrip()); - - return ret; -} - -TabPropertiesDecorator::Data::Data() = default; -TabPropertiesDecorator::Data::~Data() = default; - -const TabPropertiesDecorator::Data* TabPropertiesDecorator::Data::FromPageNode( - const PageNode* page_node) { - return TabPropertiesDataImpl::Get(PageNodeImpl::FromNode(page_node)); -} - -TabPropertiesDecorator::Data* -TabPropertiesDecorator::Data::GetOrCreateForTesting(const PageNode* page_node) { - return TabPropertiesDataImpl::GetOrCreate(PageNodeImpl::FromNode(page_node)); -} - -} // namespace performance_manager
diff --git a/components/performance_manager/decorators/tab_properties_decorator_unittest.cc b/components/performance_manager/decorators/tab_properties_decorator_unittest.cc deleted file mode 100644 index 38714507..0000000 --- a/components/performance_manager/decorators/tab_properties_decorator_unittest.cc +++ /dev/null
@@ -1,40 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "components/performance_manager/public/decorators/tab_properties_decorator.h" - -#include "components/performance_manager/test_support/decorators_utils.h" -#include "components/performance_manager/test_support/performance_manager_test_harness.h" -#include "content/public/browser/web_contents.h" -#include "testing/gtest/include/gtest/gtest.h" - -namespace performance_manager { - -class TabPropertiesDecoratorTest : public PerformanceManagerTestHarness { - public: - TabPropertiesDecoratorTest() = default; - ~TabPropertiesDecoratorTest() override = default; - TabPropertiesDecoratorTest(const TabPropertiesDecoratorTest& other) = delete; - TabPropertiesDecoratorTest& operator=(const TabPropertiesDecoratorTest&) = - delete; - - void SetUp() override { - PerformanceManagerTestHarness::SetUp(); - SetContents(CreateTestWebContents()); - } - - void TearDown() override { - DeleteContents(); - PerformanceManagerTestHarness::TearDown(); - } -}; - -TEST_F(TabPropertiesDecoratorTest, SetIsTab) { - testing::EndToEndBooleanPropertyTest( - web_contents(), &TabPropertiesDecorator::Data::GetOrCreateForTesting, - &TabPropertiesDecorator::Data::IsInTabStrip, - &TabPropertiesDecorator::SetIsTab); -} - -} // namespace performance_manager \ No newline at end of file
diff --git a/components/performance_manager/embedder/performance_manager_registry.h b/components/performance_manager/embedder/performance_manager_registry.h index 65d8250..cf4b04da9 100644 --- a/components/performance_manager/embedder/performance_manager_registry.h +++ b/components/performance_manager/embedder/performance_manager_registry.h
@@ -8,6 +8,7 @@ #include <memory> #include <vector> +#include "components/performance_manager/public/graph/page_node.h" #include "mojo/public/cpp/bindings/binder_map.h" #include "services/service_manager/public/cpp/binder_registry.h" @@ -61,6 +62,10 @@ virtual void CreatePageNodeForWebContents( content::WebContents* web_contents) = 0; + // Sets the page type for a WebContents. + virtual void SetPageType(content::WebContents* web_contents, + PageType type) = 0; + // Must be invoked for a NavigationHandle when it is committed, allowing the // PM the opportunity to apply NavigationThrottles. Typically wired up to // ContentBrowserClient::CreateThrottlesForNavigation.
diff --git a/components/performance_manager/graph/page_node.cc b/components/performance_manager/graph/page_node.cc index 81e198e..e9b65a6 100644 --- a/components/performance_manager/graph/page_node.cc +++ b/components/performance_manager/graph/page_node.cc
@@ -22,6 +22,19 @@ } // static +const char* PageNode::ToString(PageType type) { + switch (type) { + case PageType::kTab: + return "kTab"; + case PageType::kExtension: + return "kExtension"; + case PageType::kUnknown: + return "kUnknown"; + } + NOTREACHED(); +} + +// static const char* PageNode::ToString(PageNode::LoadingState loading_state) { switch (loading_state) { case LoadingState::kLoadingNotStarted:
diff --git a/components/performance_manager/graph/page_node_impl.cc b/components/performance_manager/graph/page_node_impl.cc index f4fed68..644e684b 100644 --- a/components/performance_manager/graph/page_node_impl.cc +++ b/components/performance_manager/graph/page_node_impl.cc
@@ -114,6 +114,11 @@ loading_state_.SetAndMaybeNotify(this, loading_state); } +void PageNodeImpl::SetType(PageType type) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + type_.SetAndMaybeNotify(this, type); +} + void PageNodeImpl::SetIsVisible(bool is_visible) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (is_visible_.SetAndMaybeNotify(this, is_visible)) { @@ -216,6 +221,11 @@ return embedding_type_; } +PageType PageNodeImpl::type() const { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + return type_.value(); +} + bool PageNodeImpl::is_visible() const { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return is_visible_.value(); @@ -453,6 +463,11 @@ return embedding_type(); } +PageType PageNodeImpl::GetType() const { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + return type(); +} + bool PageNodeImpl::IsVisible() const { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return is_visible();
diff --git a/components/performance_manager/graph/page_node_impl.h b/components/performance_manager/graph/page_node_impl.h index be6f46dd..cc872517 100644 --- a/components/performance_manager/graph/page_node_impl.h +++ b/components/performance_manager/graph/page_node_impl.h
@@ -59,6 +59,7 @@ // dereferenced on the UI thread. const WebContentsProxy& contents_proxy() const; + void SetType(PageType type); void SetIsVisible(bool is_visible); void SetIsAudible(bool is_audible); void SetLoadingState(LoadingState loading_state); @@ -90,6 +91,7 @@ FrameNodeImpl* opener_frame_node() const; FrameNodeImpl* embedder_frame_node() const; EmbeddingType embedding_type() const; + PageType type() const; bool is_visible() const; bool is_audible() const; LoadingState loading_state() const; @@ -200,6 +202,7 @@ const FrameNode* GetOpenerFrameNode() const override; const FrameNode* GetEmbedderFrameNode() const override; EmbeddingType GetEmbeddingType() const override; + PageType GetType() const override; bool IsVisible() const override; base::TimeDelta GetTimeSinceLastVisibilityChange() const override; bool IsAudible() const override; @@ -293,6 +296,11 @@ EmbeddingType embedding_type_ GUARDED_BY_CONTEXT(sequence_checker_) = EmbeddingType::kInvalid; + // The type of the page. + ObservedProperty::NotifiesOnlyOnChanges<PageType, + &PageNodeObserver::OnTypeChanged> + type_ GUARDED_BY_CONTEXT(sequence_checker_){PageType::kUnknown}; + // Whether or not the page is visible. Driven by browser instrumentation. // Initialized on construction. ObservedProperty::NotifiesOnlyOnChanges<bool,
diff --git a/components/performance_manager/graph/page_node_impl_describer.cc b/components/performance_manager/graph/page_node_impl_describer.cc index 10a4a22..d470841 100644 --- a/components/performance_manager/graph/page_node_impl_describer.cc +++ b/components/performance_manager/graph/page_node_impl_describer.cc
@@ -69,6 +69,8 @@ page_node_impl->contents_mime_type_); result.SetStringKey("browser_context_id", page_node_impl->browser_context_id_); + result.SetStringKey("type", + PageNode::ToString(page_node_impl->type_.value())); result.SetBoolKey("is_visible", page_node_impl->is_visible_.value()); result.SetBoolKey("is_audible", page_node_impl->is_audible_.value()); result.SetStringKey(
diff --git a/components/performance_manager/graph/page_node_impl_unittest.cc b/components/performance_manager/graph/page_node_impl_unittest.cc index e4bc73d..ddc3963e 100644 --- a/components/performance_manager/graph/page_node_impl_unittest.cc +++ b/components/performance_manager/graph/page_node_impl_unittest.cc
@@ -227,6 +227,7 @@ void(const PageNode*, const FrameNode*)); MOCK_METHOD3(OnEmbedderFrameNodeChanged, void(const PageNode*, const FrameNode*, EmbeddingType)); + MOCK_METHOD1(OnTypeChanged, void(const PageNode*)); MOCK_METHOD1(OnIsVisibleChanged, void(const PageNode*)); MOCK_METHOD1(OnIsAudibleChanged, void(const PageNode*)); MOCK_METHOD2(OnLoadingStateChanged,
diff --git a/components/performance_manager/graph_features.cc b/components/performance_manager/graph_features.cc index 0194c98..ee55b08 100644 --- a/components/performance_manager/graph_features.cc +++ b/components/performance_manager/graph_features.cc
@@ -18,7 +18,6 @@ #include "components/performance_manager/graph/process_node_impl_describer.h" #include "components/performance_manager/graph/worker_node_impl_describer.h" #include "components/performance_manager/public/decorators/page_live_state_decorator.h" -#include "components/performance_manager/public/decorators/tab_properties_decorator.h" #include "components/performance_manager/public/graph/graph.h" #include "components/performance_manager/public/metrics/metrics_collector.h" #include "components/performance_manager/v8_memory/v8_context_tracker.h" @@ -60,8 +59,6 @@ Install<ProcessHostedContentTypesAggregator>(graph); if (flags_.process_node_impl_describer) Install<ProcessNodeImplDescriber>(graph); - if (flags_.tab_properties_decorator) - Install<TabPropertiesDecorator>(graph); if (flags_.worker_node_impl_describer) Install<WorkerNodeImplDescriber>(graph);
diff --git a/components/performance_manager/graph_features_unittest.cc b/components/performance_manager/graph_features_unittest.cc index 00776c7..5f0ec9da5 100644 --- a/components/performance_manager/graph_features_unittest.cc +++ b/components/performance_manager/graph_features_unittest.cc
@@ -53,7 +53,7 @@ execution_context::ExecutionContextRegistry::GetFromGraph(&graph)); EXPECT_FALSE(v8_memory::V8ContextTracker::GetFromGraph(&graph)); - size_t graph_owned_count = 13; + size_t graph_owned_count = 12; #if !BUILDFLAG(IS_ANDROID) // The SiteDataRecorder is not available on Android. graph_owned_count++; @@ -64,7 +64,7 @@ features.ConfigureGraph(&graph); EXPECT_EQ(graph_owned_count, graph.GraphOwnedCountForTesting()); EXPECT_EQ(3u, graph.GraphRegisteredCountForTesting()); - EXPECT_EQ(9u, graph.NodeDataDescriberCountForTesting()); + EXPECT_EQ(8u, graph.NodeDataDescriberCountForTesting()); // Ensure the GraphRegistered objects can be queried directly. EXPECT_TRUE( execution_context::ExecutionContextRegistry::GetFromGraph(&graph));
diff --git a/components/performance_manager/performance_manager_lifetime.cc b/components/performance_manager/performance_manager_lifetime.cc index 9224873..3a991d1c 100644 --- a/components/performance_manager/performance_manager_lifetime.cc +++ b/components/performance_manager/performance_manager_lifetime.cc
@@ -17,7 +17,6 @@ #include "components/performance_manager/graph/worker_node_impl_describer.h" #include "components/performance_manager/performance_manager_impl.h" #include "components/performance_manager/public/decorators/page_live_state_decorator.h" -#include "components/performance_manager/public/decorators/tab_properties_decorator.h" #include "components/performance_manager/public/graph/graph.h" #include "components/performance_manager/v8_memory/v8_context_tracker.h"
diff --git a/components/performance_manager/performance_manager_registry_impl.cc b/components/performance_manager/performance_manager_registry_impl.cc index 0533f7c..7939ebce 100644 --- a/components/performance_manager/performance_manager_registry_impl.cc +++ b/components/performance_manager/performance_manager_registry_impl.cc
@@ -9,6 +9,7 @@ #include "base/observer_list.h" #include "components/performance_manager/embedder/binders.h" +#include "components/performance_manager/graph/page_node_impl.h" #include "components/performance_manager/performance_manager_tab_helper.h" #include "components/performance_manager/public/mojom/coordination_unit.mojom.h" #include "components/performance_manager/public/performance_manager.h" @@ -134,6 +135,22 @@ observer.OnPageNodeCreatedForWebContents(web_contents); } +void PerformanceManagerRegistryImpl::SetPageType( + content::WebContents* web_contents, + PageType type) { + PerformanceManagerTabHelper* tab_helper = + PerformanceManagerTabHelper::FromWebContents(web_contents); + DCHECK(tab_helper); + + PerformanceManager::CallOnGraph( + FROM_HERE, + // Unretained() is safe because PerformanceManagerTabHelper owns the + // PageNodeImpl and deletes it by posting a task to the PerformanceManager + // sequence, which will be sequenced after the task posted here. + base::BindOnce(&PageNodeImpl::SetType, + base::Unretained(tab_helper->primary_page_node()), type)); +} + PerformanceManagerRegistryImpl::Throttles PerformanceManagerRegistryImpl::CreateThrottlesForNavigation( content::NavigationHandle* handle) {
diff --git a/components/performance_manager/performance_manager_registry_impl.h b/components/performance_manager/performance_manager_registry_impl.h index 309c702f..a45b63f 100644 --- a/components/performance_manager/performance_manager_registry_impl.h +++ b/components/performance_manager/performance_manager_registry_impl.h
@@ -75,6 +75,7 @@ // PerformanceManagerRegistry: void CreatePageNodeForWebContents( content::WebContents* web_contents) override; + void SetPageType(content::WebContents* web_contents, PageType type) override; Throttles CreateThrottlesForNavigation( content::NavigationHandle* handle) override; void NotifyBrowserContextAdded(
diff --git a/components/performance_manager/public/decorators/tab_properties_decorator.h b/components/performance_manager/public/decorators/tab_properties_decorator.h deleted file mode 100644 index 68bf0282..0000000 --- a/components/performance_manager/public/decorators/tab_properties_decorator.h +++ /dev/null
@@ -1,65 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef COMPONENTS_PERFORMANCE_MANAGER_PUBLIC_DECORATORS_TAB_PROPERTIES_DECORATOR_H_ -#define COMPONENTS_PERFORMANCE_MANAGER_PUBLIC_DECORATORS_TAB_PROPERTIES_DECORATOR_H_ - -#include "components/performance_manager/public/graph/graph.h" -#include "components/performance_manager/public/graph/node_data_describer.h" - -namespace content { -class WebContents; -} // namespace content - -namespace performance_manager { - -class PageNode; - -// The TabProperties decorator is responsible for tracking properties of -// PageNodes that are tabs. All the functions that take a WebContents* as a -// parameter should only be called from the UI thread, the event will be -// forwarded to the corresponding PageNode on the Performance Manager's -// sequence. -class TabPropertiesDecorator : public GraphOwned, - public NodeDataDescriberDefaultImpl { - public: - class Data; - - // This object should only be used via its static methods. - TabPropertiesDecorator() = default; - ~TabPropertiesDecorator() override = default; - TabPropertiesDecorator(const TabPropertiesDecorator& other) = delete; - TabPropertiesDecorator& operator=(const TabPropertiesDecorator&) = delete; - - // Set the is_tab property of a PageNode. 
- static void SetIsTab(content::WebContents* contents, bool is_tab); - - static void SetIsTabForTesting(PageNode* page_node, bool is_tab); - - private: - // GraphOwned implementation: - void OnPassedToGraph(Graph* graph) override; - void OnTakenFromGraph(Graph* graph) override; - - // NodeDataDescriber implementation: - base::Value DescribePageNodeData(const PageNode* node) const override; -}; - -class TabPropertiesDecorator::Data { - public: - Data(); - virtual ~Data(); - Data(const Data& other) = delete; - Data& operator=(const Data&) = delete; - - // Indicates if a PageNode belongs to a tab strip. - virtual bool IsInTabStrip() const = 0; - - static const Data* FromPageNode(const PageNode* page_node); - static Data* GetOrCreateForTesting(const PageNode* page_node); -}; - -} // namespace performance_manager - -#endif // COMPONENTS_PERFORMANCE_MANAGER_PUBLIC_DECORATORS_TAB_PROPERTIES_DECORATOR_H_
diff --git a/components/performance_manager/public/graph/page_node.h b/components/performance_manager/public/graph/page_node.h index 059e101..d15a2f9d 100644 --- a/components/performance_manager/public/graph/page_node.h +++ b/components/performance_manager/public/graph/page_node.h
@@ -25,6 +25,15 @@ class FrameNode; class PageNodeObserver; +enum class PageType { + // A browser tab. + kTab, + // An extension background page. + kExtension, + // Anything else. + kUnknown, +}; + // A PageNode represents the root of a FrameTree, or equivalently a WebContents. // These may correspond to normal tabs, WebViews, Portals, Chrome Apps or // Extensions. @@ -70,7 +79,8 @@ kLoadedIdle, }; - // Returns a string for a PageNode::LoadingState enumeration. + // Returns a string for an enumeration value. + static const char* ToString(PageType type); static const char* ToString(PageNode::LoadingState loading_state); // State of a page. Pages can be born in "kActive" or "kPrerendering" state. @@ -114,6 +124,9 @@ // an embedder. virtual EmbeddingType GetEmbeddingType() const = 0; + // Returns the type of the page. + virtual PageType GetType() const = 0; + // Returns true if this page is currently visible, false otherwise. // See PageNodeObserver::OnIsVisibleChanged. virtual bool IsVisible() const = 0; @@ -249,6 +262,9 @@ const FrameNode* previous_embedder, EmbeddingType previous_embedder_type) = 0; + // Invoked when the GetType property changes. + virtual void OnTypeChanged(const PageNode* page_node) = 0; + // Invoked when the IsVisible property changes. virtual void OnIsVisibleChanged(const PageNode* page_node) = 0; @@ -327,6 +343,7 @@ const PageNode* page_node, const FrameNode* previous_embedder, EmbeddingType previous_embedding_type) override {} + void OnTypeChanged(const PageNode* page_node) override {} void OnIsVisibleChanged(const PageNode* page_node) override {} void OnIsAudibleChanged(const PageNode* page_node) override {} void OnLoadingStateChanged(const PageNode* page_node,
diff --git a/components/remote_cocoa/app_shim/certificate_viewer.mm b/components/remote_cocoa/app_shim/certificate_viewer.mm index 0c12e8b..725df3315 100644 --- a/components/remote_cocoa/app_shim/certificate_viewer.mm +++ b/components/remote_cocoa/app_shim/certificate_viewer.mm
@@ -11,7 +11,7 @@ #include "base/mac/foundation_util.h" #include "base/mac/scoped_cftyperef.h" #include "base/notreached.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "net/cert/x509_util_mac.h" namespace remote_cocoa {
diff --git a/components/segmentation_platform/internal/BUILD.gn b/components/segmentation_platform/internal/BUILD.gn index 4157fc05..83d0f94 100644 --- a/components/segmentation_platform/internal/BUILD.gn +++ b/components/segmentation_platform/internal/BUILD.gn
@@ -99,8 +99,6 @@ "scheduler/model_execution_scheduler.h", "scheduler/model_execution_scheduler_impl.cc", "scheduler/model_execution_scheduler_impl.h", - "segment_id_convertor.cc", - "segment_id_convertor.h", "segmentation_platform_service_impl.cc", "segmentation_platform_service_impl.h", "segmentation_ukm_helper.cc", @@ -160,9 +158,6 @@ "//url:url", ] - public_deps = - [ "//components/optimization_guide/proto:optimization_guide_proto" ] - if (is_android) { sources += [ "android/segmentation_platform_service_android.cc", @@ -182,11 +177,14 @@ "execution/optimization_guide/optimization_guide_segmentation_model_provider.h", "execution/optimization_guide/segmentation_model_executor.cc", "execution/optimization_guide/segmentation_model_executor.h", + "segment_id_convertor.cc", + "segment_id_convertor.h", ] deps = [ ":internal", "//base", "//components/optimization_guide/core", + "//components/optimization_guide/proto:optimization_guide_proto", "//components/segmentation_platform/internal/proto", "//components/segmentation_platform/public", ] @@ -270,6 +268,7 @@ "//components/leveldb_proto:test_support", "//components/optimization_guide/core", "//components/optimization_guide/core:test_support", + "//components/optimization_guide/proto:optimization_guide_proto", "//components/prefs", "//components/prefs:test_support", "//components/segmentation_platform/internal/proto",
diff --git a/components/segmentation_platform/internal/android/java/src/org/chromium/components/segmentation_platform/SegmentationPlatformServiceImpl.java b/components/segmentation_platform/internal/android/java/src/org/chromium/components/segmentation_platform/SegmentationPlatformServiceImpl.java index 38f34d5d..2737e52a 100644 --- a/components/segmentation_platform/internal/android/java/src/org/chromium/components/segmentation_platform/SegmentationPlatformServiceImpl.java +++ b/components/segmentation_platform/internal/android/java/src/org/chromium/components/segmentation_platform/SegmentationPlatformServiceImpl.java
@@ -8,8 +8,7 @@ import org.chromium.base.annotations.CalledByNative; import org.chromium.base.annotations.JNINamespace; import org.chromium.base.annotations.NativeMethods; -import org.chromium.components.optimization_guide.proto.ModelsProto; -import org.chromium.components.optimization_guide.proto.ModelsProto.OptimizationTarget; +import org.chromium.components.segmentation_platform.proto.SegmentationProto.SegmentId; /** * Java side of the JNI bridge between SegmentationPlatformServiceImpl in Java @@ -49,8 +48,8 @@ @CalledByNative private static SegmentSelectionResult createSegmentSelectionResult( boolean isReady, int selectedSegment) { - OptimizationTarget segment = ModelsProto.OptimizationTarget.forNumber(selectedSegment); - if (segment == null) segment = OptimizationTarget.OPTIMIZATION_TARGET_UNKNOWN; + SegmentId segment = SegmentId.forNumber(selectedSegment); + if (segment == null) segment = SegmentId.OPTIMIZATION_TARGET_UNKNOWN; return new SegmentSelectionResult(isReady, segment); }
diff --git a/components/segmentation_platform/internal/android/segmentation_platform_service_android.cc b/components/segmentation_platform/internal/android/segmentation_platform_service_android.cc index d580aa7..bad5ce9 100644 --- a/components/segmentation_platform/internal/android/segmentation_platform_service_android.cc +++ b/components/segmentation_platform/internal/android/segmentation_platform_service_android.cc
@@ -26,7 +26,7 @@ const SegmentSelectionResult& result) { int selected_segment = result.segment.has_value() ? result.segment.value() - : OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN; + : proto::SegmentId::OPTIMIZATION_TARGET_UNKNOWN; return Java_SegmentationPlatformServiceImpl_createSegmentSelectionResult( env, result.is_ready, selected_segment); }
diff --git a/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc b/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc index 2387e054..92a66ef 100644 --- a/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc +++ b/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc
@@ -106,7 +106,7 @@ SegmentSelectionResult result; result.is_ready = is_ready; if (is_ready) - result.segment = SegmentIdToOptimizationTarget(expected); + result.segment = expected; base::RunLoop loop; segmentation_platform_service_impl_->GetSelectedSegment( segmentation_key, @@ -123,7 +123,7 @@ SegmentSelectionResult result; result.is_ready = is_ready; if (is_ready) - result.segment = SegmentIdToOptimizationTarget(expected); + result.segment = expected; ASSERT_EQ(result, segmentation_platform_service_impl_->GetCachedSegmentResult( segmentation_key)); @@ -256,7 +256,7 @@ GetSelectedSegmentBeforeInitialization) { SegmentSelectionResult expected; expected.is_ready = true; - expected.segment = OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE; + expected.segment = proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE; base::RunLoop loop; segmentation_platform_service_impl_->GetSelectedSegment( kTestSegmentationKey1,
diff --git a/components/segmentation_platform/internal/selection/segment_selector_impl.cc b/components/segmentation_platform/internal/selection/segment_selector_impl.cc index 91bb3c4..074af992 100644 --- a/components/segmentation_platform/internal/selection/segment_selector_impl.cc +++ b/components/segmentation_platform/internal/selection/segment_selector_impl.cc
@@ -20,7 +20,6 @@ #include "components/segmentation_platform/internal/platform_options.h" #include "components/segmentation_platform/internal/proto/model_metadata.pb.h" #include "components/segmentation_platform/internal/proto/model_prediction.pb.h" -#include "components/segmentation_platform/internal/segment_id_convertor.h" #include "components/segmentation_platform/internal/selection/experimental_group_recorder.h" #include "components/segmentation_platform/internal/selection/segment_result_provider.h" #include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h" @@ -112,8 +111,7 @@ stats::SegmentationKeyToTrialName(config_->segmentation_key); std::string group_name; if (selected_segment.has_value()) { - selected_segment_last_session_.segment = - SegmentIdToOptimizationTarget(selected_segment->segment_id); + selected_segment_last_session_.segment = selected_segment->segment_id; selected_segment_last_session_.is_ready = true; stats::RecordSegmentSelectionFailure( config_->segmentation_key, @@ -250,7 +248,7 @@ DCHECK(!callback.is_null()); SegmentSelectionResult result; result.is_ready = true; - result.segment = SegmentIdToOptimizationTarget(selected_segment); + result.segment = selected_segment; std::move(callback).Run(result); } else { DCHECK(callback.is_null());
diff --git a/components/segmentation_platform/internal/selection/segment_selector_unittest.cc b/components/segmentation_platform/internal/selection/segment_selector_unittest.cc index 2941717..e2b8b86 100644 --- a/components/segmentation_platform/internal/selection/segment_selector_unittest.cc +++ b/components/segmentation_platform/internal/selection/segment_selector_unittest.cc
@@ -15,7 +15,6 @@ #include "components/segmentation_platform/internal/execution/mock_model_provider.h" #include "components/segmentation_platform/internal/metadata/metadata_utils.h" #include "components/segmentation_platform/internal/metric_filter_utils.h" -#include "components/segmentation_platform/internal/segment_id_convertor.h" #include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h" #include "components/segmentation_platform/public/config.h" #include "components/segmentation_platform/public/field_trial_register.h" @@ -212,8 +211,7 @@ base::BindOnce( [](base::OnceClosure quit, const SegmentSelectionResult& result) { EXPECT_TRUE(result.is_ready); - EXPECT_EQ(kSegmentId2, - OptimizationTargetToSegmentId(*result.segment)); + EXPECT_EQ(kSegmentId2, *result.segment); std::move(quit).Run(); }, wait_for_selection.QuitClosure())); @@ -375,7 +373,7 @@ segment_selector_->OnPlatformInitialized(nullptr); SegmentSelectionResult result; - result.segment = SegmentIdToOptimizationTarget(segment_id0); + result.segment = segment_id0; result.is_ready = true; GetSelectedSegment(result); ASSERT_EQ(result, segment_selector_->GetCachedSegmentResult());
diff --git a/components/segmentation_platform/internal/service_proxy_impl.cc b/components/segmentation_platform/internal/service_proxy_impl.cc index 2bcb0da..10f6e23 100644 --- a/components/segmentation_platform/internal/service_proxy_impl.cc +++ b/components/segmentation_platform/internal/service_proxy_impl.cc
@@ -14,7 +14,6 @@ #include "components/segmentation_platform/internal/database/signal_storage_config.h" #include "components/segmentation_platform/internal/metadata/metadata_utils.h" #include "components/segmentation_platform/internal/scheduler/execution_service.h" -#include "components/segmentation_platform/internal/segment_id_convertor.h" #include "components/segmentation_platform/internal/segmentation_platform_service_impl.h" #include "components/segmentation_platform/internal/selection/segment_selector_impl.h" #include "components/segmentation_platform/public/config.h" @@ -172,12 +171,12 @@ if (segment_selectors_ && segment_selectors_->find(config->segmentation_key) != segment_selectors_->end()) { - absl::optional<OptimizationTarget> target = + absl::optional<proto::SegmentId> target = segment_selectors_->at(config->segmentation_key) ->GetCachedSegmentResult() .segment; if (target.has_value()) { - selected = OptimizationTargetToSegmentId(*target); + selected = *target; } } result.emplace_back(config->segmentation_key, selected);
diff --git a/components/segmentation_platform/public/BUILD.gn b/components/segmentation_platform/public/BUILD.gn index 760107d..d5aa36e 100644 --- a/components/segmentation_platform/public/BUILD.gn +++ b/components/segmentation_platform/public/BUILD.gn
@@ -30,7 +30,6 @@ deps = [ "//base", "//components/keyed_service/core", - "//components/optimization_guide/proto:optimization_guide_proto", ] } @@ -56,8 +55,6 @@ deps = [ "//base:base_java" ] - public_deps = [ - "//components/optimization_guide/proto:optimization_guide_proto_java", - ] + public_deps = [ "//components/segmentation_platform/public/proto:segmentation_platform_proto_java" ] } }
diff --git a/components/segmentation_platform/public/android/java/src/org/chromium/components/segmentation_platform/SegmentSelectionResult.java b/components/segmentation_platform/public/android/java/src/org/chromium/components/segmentation_platform/SegmentSelectionResult.java index d4e121ba..7300503 100644 --- a/components/segmentation_platform/public/android/java/src/org/chromium/components/segmentation_platform/SegmentSelectionResult.java +++ b/components/segmentation_platform/public/android/java/src/org/chromium/components/segmentation_platform/SegmentSelectionResult.java
@@ -4,7 +4,7 @@ package org.chromium.components.segmentation_platform; -import org.chromium.components.optimization_guide.proto.ModelsProto.OptimizationTarget; +import org.chromium.components.segmentation_platform.proto.SegmentationProto.SegmentId; /** * Java counterpart of native SegmentSelectionResult. Contains the result of segment selection. @@ -16,10 +16,10 @@ /** * The result of segment selection. */ - public final OptimizationTarget selectedSegment; + public final SegmentId selectedSegment; /** Constructor */ - public SegmentSelectionResult(boolean isReady, OptimizationTarget selectedSegment) { + public SegmentSelectionResult(boolean isReady, SegmentId selectedSegment) { this.isReady = isReady; this.selectedSegment = selectedSegment; }
diff --git a/components/segmentation_platform/public/proto/BUILD.gn b/components/segmentation_platform/public/proto/BUILD.gn index 8118438..2ae31b38 100644 --- a/components/segmentation_platform/public/proto/BUILD.gn +++ b/components/segmentation_platform/public/proto/BUILD.gn
@@ -4,7 +4,18 @@ import("//third_party/protobuf/proto_library.gni") +if (is_android) { + import("//build/config/android/rules.gni") +} + proto_library("proto") { proto_in_dir = "//" sources = [ "segmentation_platform.proto" ] } + +if (is_android) { + proto_java_library("segmentation_platform_proto_java") { + proto_path = "//" + sources = [ "segmentation_platform.proto" ] + } +}
diff --git a/components/segmentation_platform/public/proto/segmentation_platform.proto b/components/segmentation_platform/public/proto/segmentation_platform.proto index 557a51da..c0ff147 100644 --- a/components/segmentation_platform/public/proto/segmentation_platform.proto +++ b/components/segmentation_platform/public/proto/segmentation_platform.proto
@@ -4,6 +4,8 @@ syntax = "proto2"; option optimize_for = LITE_RUNTIME; +option java_package = "org.chromium.components.segmentation_platform.proto"; +option java_outer_classname = "SegmentationProto"; package segmentation_platform.proto;
diff --git a/components/segmentation_platform/public/segment_selection_result.h b/components/segmentation_platform/public/segment_selection_result.h index 917ba61..70e223f 100644 --- a/components/segmentation_platform/public/segment_selection_result.h +++ b/components/segmentation_platform/public/segment_selection_result.h
@@ -5,11 +5,9 @@ #ifndef COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_SEGMENT_SELECTION_RESULT_H_ #define COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_SEGMENT_SELECTION_RESULT_H_ -#include "components/optimization_guide/proto/models.pb.h" +#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h" #include "third_party/abseil-cpp/absl/types/optional.h" -using optimization_guide::proto::OptimizationTarget; - namespace segmentation_platform { // The result of segmentation and related metadata. @@ -26,7 +24,7 @@ // The result of segmentation. Can be empty if the the backend couldn't select // a segment with confidence. - absl::optional<OptimizationTarget> segment; + absl::optional<proto::SegmentId> segment; }; } // namespace segmentation_platform
diff --git a/content/browser/client_hints/client_hints.cc b/content/browser/client_hints/client_hints.cc index b6dc7a9cb..c6a04c61 100644 --- a/content/browser/client_hints/client_hints.cc +++ b/content/browser/client_hints/client_hints.cc
@@ -458,9 +458,8 @@ } // Returns true iff the `url` is embedded inside a frame that has the -// corresponding Sec-CH-UA-Reduced, Sec-CH-UA-Full, or -// Sec-CH-Partitioned-Cookies client hint and thus, is enrolled in the -// UserAgentReduction, SendFullUserAgentAfterReduction, or PartitionedCookies +// corresponding Sec-CH-UA-Reduced or Sec-CH-UA-Full client hint and thus, is +// enrolled in the UserAgentReduction or SendFullUserAgentAfterReduction // Origin Trial. // // TODO(crbug.com/1258063): Remove when the UserAgentReduction and @@ -510,8 +509,7 @@ for (auto it = accept_ch->begin(); it != accept_ch->end();) { if (*it == WebClientHintsType::kUAReduced || - *it == WebClientHintsType::kFullUserAgent || - *it == WebClientHintsType::kPartitionedCookies) { + *it == WebClientHintsType::kFullUserAgent) { ++it; } else { it = accept_ch->erase(it); @@ -586,9 +584,6 @@ is_embedder_ua_full = IsOriginTrialHintEnabledForFrame( trial_origin, outermost_main_frame_origin, frame_tree_node, delegate, WebClientHintsType::kFullUserAgent); - is_embedder_partitioned_cookies = IsOriginTrialHintEnabledForFrame( - trial_origin, outermost_main_frame_origin, frame_tree_node, delegate, - WebClientHintsType::kPartitionedCookies); } // Record the time spent getting the client hints. @@ -613,12 +608,6 @@ // receive the full User-Agent header, so we want to also send the full // User-Agent for the embedded request as well. bool is_embedder_ua_full = false; - // If true, one of the ancestor requests in the path to this request had - // Sec-CH-Partitioned-Cookies in their Accept-CH cache. Only appplies to - // embedded requests (top-level requests will always set this to false). 
- // - // If the embedder of the - bool is_embedder_partitioned_cookies = false; url::Origin resource_origin; bool is_outermost_main_frame = false; url::Origin outermost_main_frame_origin; @@ -628,8 +617,7 @@ bool SkipPermissionPolicyCheck(WebClientHintsType type) { return type == WebClientHintsType::kUAReduced || - type == WebClientHintsType::kFullUserAgent || - type == WebClientHintsType::kPartitionedCookies; + type == WebClientHintsType::kFullUserAgent; } bool IsClientHintEnabled(const ClientHintsExtendedData& data, @@ -638,9 +626,7 @@ (type == WebClientHintsType::kUAReduced && data.is_embedder_ua_reduced) || (type == WebClientHintsType::kFullUserAgent && - data.is_embedder_ua_full) || - (type == WebClientHintsType::kPartitionedCookies && - data.is_embedder_partitioned_cookies); + data.is_embedder_ua_full); } bool IsClientHintAllowed(const ClientHintsExtendedData& data, @@ -949,11 +935,6 @@ AddPrefersColorSchemeHeader(headers, frame_tree_node); } - if (ShouldAddClientHint(data, WebClientHintsType::kPartitionedCookies)) { - SetHeaderToString(headers, WebClientHintsType::kPartitionedCookies, - SerializeHeaderString(true)); - } - if (ShouldAddClientHint(data, WebClientHintsType::kSaveData)) AddSaveDataHeader(headers, context); @@ -1092,15 +1073,6 @@ enabled_hints.GetEnabledHints(); PersistAcceptCH(origin, frame_tree_node->GetParentOrOuterDocument(), delegate, persisted_hints); - if (std::find(persisted_hints.begin(), persisted_hints.end(), - WebClientHintsType::kPartitionedCookies) == - persisted_hints.end()) { - if (auto* cookie_manager = frame_tree_node->current_frame_host() - ->GetStoragePartition() - ->GetCookieManagerForBrowserProcess()) { - cookie_manager->ConvertPartitionedCookiesToUnpartitioned(origin.GetURL()); - } - } return persisted_hints; } @@ -1133,10 +1105,6 @@ !base::Contains(hints, WebClientHintsType::kFullUserAgent)) { hints.push_back(WebClientHintsType::kFullUserAgent); } - if (data.is_embedder_partitioned_cookies && - !base::Contains(hints, 
WebClientHintsType::kPartitionedCookies)) { - hints.push_back(WebClientHintsType::kPartitionedCookies); - } return hints; }
diff --git a/content/browser/net/http_cookie_browsertest.cc b/content/browser/net/http_cookie_browsertest.cc index 4fea2384..eb1c217 100644 --- a/content/browser/net/http_cookie_browsertest.cc +++ b/content/browser/net/http_cookie_browsertest.cc
@@ -5,14 +5,18 @@ #include "base/strings/strcat.h" #include "base/strings/string_split.h" #include "base/strings/stringprintf.h" +#include "base/test/bind.h" #include "base/test/scoped_feature_list.h" #include "content/public/browser/browser_context.h" +#include "content/public/browser/storage_partition.h" #include "content/public/browser/web_contents.h" #include "content/public/test/browser_test.h" #include "content/public/test/browser_test_utils.h" #include "content/public/test/content_browser_test.h" +#include "content/public/test/content_browser_test_utils.h" #include "content/public/test/frame_test_utils.h" #include "content/public/test/test_navigation_observer.h" +#include "content/public/test/url_loader_interceptor.h" #include "content/shell/browser/shell.h" #include "net/base/features.h" #include "net/cookies/canonical_cookie_test_helpers.h" @@ -755,5 +759,551 @@ HttpCookieBrowserTest, ::testing::Bool()); +struct OriginTrialTestOptions { + bool has_ot_token = true; + bool valid_ot_token = true; + bool has_set_cookie = true; + bool has_partitioned = true; +}; + +// This class tests the origin trial mechanism for partitioned cookies. +// Partitioned cookies should be reverted to unpartitioned if the navigation +// has a Set-Cookie header with the Partitioned attribute and the site does +// not send a valid Origin-Trial header. +// This test exercises the origin trial for top-level navigation requests. 
+class PartitionedCookiesOriginTrialBrowserTest : public ContentBrowserTest { + protected: + void SetUp() override { + scoped_feature_list_.InitWithFeatures({net::features::kPartitionedCookies}, + {}); + ContentBrowserTest::SetUp(); + } + + void SetUpOnMainThread() override { + url_loader_interceptor_ = + std::make_unique<URLLoaderInterceptor>(base::BindRepeating( + &PartitionedCookiesOriginTrialBrowserTest::InterceptRequest, + base::Unretained(this))); + } + + void TearDownOnMainThread() override { + url_loader_interceptor_.reset(); + ContentBrowserTest::TearDownOnMainThread(); + } + + void SetTestOptions(const OriginTrialTestOptions& test_setting, + const std::set<GURL>& expected_request_urls) { + test_options_ = test_setting; + expected_request_urls_ = expected_request_urls; + } + + virtual const char* OriginTrialToken() const { + // The test Origin Trial token was generated by running: + // python tools/origin_trials/generate_token.py https://127.0.0.1:44444 \ + // PartitionedCookies \ + // --expire-timestamp=2000000000 + return "A4s/" + "iPKfhEfgqQIIuz4zLuCpONpXOuYyJFBhBx1MfgS1aNhFujyhsg4lkfTRfjzQCI3aUbM" + "wtNm25elLTR4UIgAAAABceyJvcmlnaW4iOiAiaHR0cHM6Ly8xMjcuMC4wLjE6NDQ0ND" + "QiLCAiZmVhdHVyZSI6ICJQYXJ0aXRpb25lZENvb2tpZXMiLCAiZXhwaXJ5IjogMjAwM" + "DAwMDAwMH0="; + } + + // We use URLLoaderInterceptor because we cannot control which port that + // EmbeddedTestServer uses. Since origin trials depend on the entire origin + // (including port) we need to intercept the requests using + // URLLoaderInterceptor. 
+ bool InterceptRequest(URLLoaderInterceptor::RequestParams* params) { + if (expected_request_urls_.find(params->url_request.url) == + expected_request_urls_.end()) { + return false; + } + + std::string headers = "HTTP/1.1 200 OK\nContent-Type: text/html\n"; + std::string body = "<html><body>Hello world!</body></html>"; + if (test_options_.has_set_cookie) { + base::StrAppend( + &headers, + {"Set-Cookie: __Host-foo=bar; Secure; Path=/; SameSite=None;", + test_options_.has_partitioned ? " Partitioned" : "", "\n"}); + } + if (test_options_.has_ot_token) { + base::StrAppend( + &headers, + {"Origin-Trial: ", + test_options_.valid_ot_token ? OriginTrialToken() : "invalid", + "\n"}); + } + URLLoaderInterceptor::WriteResponse(headers, body, params->client.get(), + absl::nullopt, + /*url=*/params->url_request.url); + return true; + } + + network::mojom::CookieManager* GetCookieManager() { + return shell() + ->web_contents() + ->GetBrowserContext() + ->GetDefaultStoragePartition() + ->GetCookieManagerForBrowserProcess(); + } + + void SetCookie(const std::string& name, + const std::string& value, + const GURL& url, + const absl::optional<net::CookiePartitionKey>& partition_key) { + auto cookie = net::CanonicalCookie::CreateUnsafeCookieForTesting( + name, value, url.host(), "/", base::Time::Now() - base::Days(1), + base::Time::Now() + base::Days(1), base::Time::Now(), base::Time::Now(), + /*secure=*/true, /*httponly=*/false, + net::CookieSameSite::NO_RESTRICTION, + net::CookiePriority::COOKIE_PRIORITY_DEFAULT, /*same_party=*/false, + partition_key); + EXPECT_TRUE(cookie->IsCanonical()); + + base::RunLoop run_loop; + GetCookieManager()->SetCanonicalCookie( + *cookie, url, net::CookieOptions::MakeAllInclusive(), + base::BindLambdaForTesting( + [&](net::CookieAccessResult set_cookie_result) { + EXPECT_TRUE(set_cookie_result.status.IsInclude()); + run_loop.Quit(); + })); + run_loop.Run(); + } + + std::vector<net::CanonicalCookie> GetCookies(const GURL& url) { + 
std::vector<net::CanonicalCookie> cookies; + + base::RunLoop run_loop; + GetCookieManager()->GetCookieList( + url, net::CookieOptions::MakeAllInclusive(), + net::CookiePartitionKeyCollection::ContainsAll(), + base::BindLambdaForTesting( + [&](const std::vector<::net::CookieWithAccessResult>& result, + const std::vector<::net::CookieWithAccessResult>& + excluded_cookies) { + EXPECT_TRUE(excluded_cookies.empty()); + for (const auto& el : result) { + cookies.push_back(el.cookie); + } + run_loop.Quit(); + })); + run_loop.Run(); + + return cookies; + } + + const GURL CookieUrl() { return GURL("https://127.0.0.1:44444"); } + + void WaitForPage(const GURL& url) { + EXPECT_TRUE(NavigateToURL(shell(), url)); + WebContents* contents = shell()->web_contents(); + EXPECT_TRUE(WaitForLoadStop(contents)); + EXPECT_TRUE(WaitForRenderFrameReady(contents->GetMainFrame())); + } + + protected: + std::unique_ptr<URLLoaderInterceptor> url_loader_interceptor_; + OriginTrialTestOptions test_options_; + std::set<GURL> expected_request_urls_; + base::test::ScopedFeatureList scoped_feature_list_; +}; + +// Test that the partitioned cookie set before the request remains partitioned +// when the site sends a Set-Cookie header with the Partitioned attribute +// and a valid OT token. +IN_PROC_BROWSER_TEST_F(PartitionedCookiesOriginTrialBrowserTest, + ValidParticipant) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions( + {/*has_ot_token=*/true, /*valid_ot_token=*/true, /*has_set_cookie=*/true, + /*has_partitioned=*/true}, + {CookieUrl()}); + + WaitForPage(CookieUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); +} + +// Test that the partitioned cookie is reverted to unpartitioned if the site +// sends a Set-Cookie with Partitioned but an invalid OT token. 
+IN_PROC_BROWSER_TEST_F(PartitionedCookiesOriginTrialBrowserTest, InvalidToken) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions( + {/*has_ot_token=*/true, /*valid_ot_token=*/false, /*has_set_cookie=*/true, + /*has_partitioned=*/true}, + {CookieUrl()}); + + WaitForPage(CookieUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_FALSE(cookies[0].IsPartitioned()); +} + +// Test that the partitioned cookie is reverted to unpartitioned if the site +// sends a Set-Cookie with Partitioned but does not send an OT token. +IN_PROC_BROWSER_TEST_F(PartitionedCookiesOriginTrialBrowserTest, NoToken) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions({/*has_ot_token=*/false, /*valid_ot_token=*/false, + /*has_set_cookie=*/true, + /*has_partitioned=*/true}, + {CookieUrl()}); + + WaitForPage(CookieUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_FALSE(cookies[0].IsPartitioned()); +} + +// The partitioned cookie should stay partitioned since we should not check +// the OT token on responses with no Set-Cookie header. +IN_PROC_BROWSER_TEST_F(PartitionedCookiesOriginTrialBrowserTest, NoSetCookie) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions({/*has_ot_token=*/false, /*valid_ot_token=*/false, + /*has_set_cookie=*/false, + /*has_partitioned=*/true}, + {CookieUrl()}); + + WaitForPage(CookieUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); +} + +// The partitioned cookie should stay partitioned since we should not check +// the OT token on responses with a Set-Cookie header without Partitioned.
+IN_PROC_BROWSER_TEST_F(PartitionedCookiesOriginTrialBrowserTest, + NoPartitioned) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions({/*has_ot_token=*/false, /*valid_ot_token=*/false, + /*has_set_cookie=*/true, + /*has_partitioned=*/false}, + {CookieUrl()}); + + WaitForPage(CookieUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); +} + +// This class tests the origin trial mechanism for partitioned cookies. +// Partitioned cookies should be reverted to unpartitioned if the navigation +// has a Set-Cookie header with the Partitioned attribute and the site does +// not send a valid Origin-Trial header. +// This test exercises navigation requests in <iframe> embeds. +class EmbedPartitionedCookiesOriginTrialBrowserTest + : public PartitionedCookiesOriginTrialBrowserTest { + public: + void SetUpOnMainThread() override { + url_loader_interceptor_ = + std::make_unique<URLLoaderInterceptor>(base::BindRepeating( + &EmbedPartitionedCookiesOriginTrialBrowserTest::InterceptRequest, + base::Unretained(this))); + } + + const char* OriginTrialToken() const override { + // The test Origin Trial token was generated by running: + // python tools/origin_trials/generate_token.py https://127.0.0.1:44444 \ + // PartitionedCookies \ + // --expire-timestamp=2000000000 + // --is-third-party + return "A1mBOyrOKGAaaoT8mjM1qSNrOSrdDUa9WyqicVLlDGW3feIBSdWqSiHDAXUeKkGKaVq" + "UiCX8avwCM0gpG5LtxgAAAAByeyJvcmlnaW4iOiAiaHR0cHM6Ly8xMjcuMC4wLjE6ND" + "Q0NDQiLCAiZmVhdHVyZSI6ICJQYXJ0aXRpb25lZENvb2tpZXMiLCAiZXhwaXJ5IjogM" + "jAwMDAwMDAwMCwgImlzVGhpcmRQYXJ0eSI6IHRydWV9"; + } + + // We use URLLoaderInterceptor because we cannot control which port that + // EmbeddedTestServer uses. Since origin trials depend on the entire origin + // (including port) we need to intercept the requests using + // URLLoaderInterceptor. 
+ bool InterceptRequest(URLLoaderInterceptor::RequestParams* params) { + if (expected_request_urls_.find(params->url_request.url) == + expected_request_urls_.end()) { + return false; + } + + if (params->url_request.url == TopLevelUrl()) { + std::string headers = "HTTP/1.1 200 OK\nContent-Type: text/html\n"; + std::string body = "<html><body><iframe src=\""; + base::StrAppend(&body, {CookieUrl().spec(), "\"></body></html>"}); + URLLoaderInterceptor::WriteResponse(headers, body, params->client.get(), + absl::nullopt, + /*url=*/params->url_request.url); + return true; + } + + return PartitionedCookiesOriginTrialBrowserTest::InterceptRequest(params); + } + + GURL TopLevelUrl() { return GURL("https://mysite.com:44444"); } +}; + +// Test that the partitioned cookie set before the request remains partitioned +// when the site sends a Set-Cookie header with the Partitioned attribute +// and a valid OT token. +IN_PROC_BROWSER_TEST_F(EmbedPartitionedCookiesOriginTrialBrowserTest, + ValidParticipant) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions( + {/*has_ot_token=*/true, /*valid_ot_token=*/true, /*has_set_cookie=*/true, + /*has_partitioned=*/true}, + {TopLevelUrl(), CookieUrl()}); + + WaitForPage(TopLevelUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); +} + +// Test that the partitioned cookie is reverted to unpartitioned if the site +// sends a Set-Cookie with Partitioned but an invalid OT token. 
+IN_PROC_BROWSER_TEST_F(EmbedPartitionedCookiesOriginTrialBrowserTest, + InvalidToken) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions( + {/*has_ot_token=*/true, /*valid_ot_token=*/false, /*has_set_cookie=*/true, + /*has_partitioned=*/true}, + {TopLevelUrl(), CookieUrl()}); + + WaitForPage(TopLevelUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_FALSE(cookies[0].IsPartitioned()); +} + +// Test that the partitioned cookie is reverted to unpartitioned if the site +// sends a Set-Cookie with Partitioned but does not send an OT token. +IN_PROC_BROWSER_TEST_F(EmbedPartitionedCookiesOriginTrialBrowserTest, NoToken) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions({/*has_ot_token=*/false, /*valid_ot_token=*/false, + /*has_set_cookie=*/true, + /*has_partitioned=*/true}, + {TopLevelUrl(), CookieUrl()}); + + WaitForPage(TopLevelUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_FALSE(cookies[0].IsPartitioned()); +} + +// The partitioned cookie should stay partitioned since we should not check +// the OT token on responses with no Set-Cookie header.
+IN_PROC_BROWSER_TEST_F(EmbedPartitionedCookiesOriginTrialBrowserTest, + NoSetCookie) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions({/*has_ot_token=*/false, /*valid_ot_token=*/false, + /*has_set_cookie=*/false, + /*has_partitioned=*/true}, + {TopLevelUrl(), CookieUrl()}); + + WaitForPage(TopLevelUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); +} + +// The partitioned cookie should stay partitioned since we should not check +// the OT token on responses with a Set-Cookie header without Partitioned. +IN_PROC_BROWSER_TEST_F(EmbedPartitionedCookiesOriginTrialBrowserTest, + NoPartitioned) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions({/*has_ot_token=*/false, /*valid_ot_token=*/false, + /*has_set_cookie=*/true, + /*has_partitioned=*/false}, + {TopLevelUrl(), CookieUrl()}); + + WaitForPage(TopLevelUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); +} + +// This test exercises the partitioned cookie origin trial for subresource +// requests. This browser test is meant to verify the feature works end-to-end +// though there is nothing about this test particularly related to navigation. +// +// It lives in this file so that it can subclass the other partitioned cookies +// origin trial tests (which do exercise navigation requests) and reuse their +// test infrastructure. +// TODO(https://crbug.com/1296161): Move to another file/delete this test when +// OT is over.
+class SubresourcePartitionedCookiesOriginTrialBrowserTest + : public EmbedPartitionedCookiesOriginTrialBrowserTest { + public: + void SetUpOnMainThread() override { + url_loader_interceptor_ = std::make_unique< + URLLoaderInterceptor>(base::BindRepeating( + &SubresourcePartitionedCookiesOriginTrialBrowserTest::InterceptRequest, + base::Unretained(this))); + } + + bool InterceptRequest(URLLoaderInterceptor::RequestParams* params) { + if (expected_request_urls_.find(params->url_request.url) == + expected_request_urls_.end()) { + return false; + } + + if (params->url_request.url == TopLevelUrl()) { + std::string headers = "HTTP/1.1 200 OK\nContent-Type: text/html\n"; + std::string body = "<html><body><script src=\""; + base::StrAppend(&body, + {CookieUrl().spec(), "\"></script></body></html>"}); + URLLoaderInterceptor::WriteResponse(headers, body, params->client.get(), + absl::nullopt, + /*url=*/params->url_request.url); + return true; + } + + std::string headers = "HTTP/1.1 200 OK\nContent-Type: text/javascript\n"; + std::string body = "console.log('Hello world!');"; + if (test_options_.has_set_cookie) { + base::StrAppend( + &headers, + {"Set-Cookie: __Host-foo=bar; Secure; Path=/; SameSite=None;", + test_options_.has_partitioned ? " Partitioned" : "", "\n"}); + } + if (test_options_.has_ot_token) { + base::StrAppend( + &headers, + {"Origin-Trial: ", + test_options_.valid_ot_token ? OriginTrialToken() : "invalid", + "\n"}); + } + URLLoaderInterceptor::WriteResponse(headers, body, params->client.get(), + absl::nullopt, + /*url=*/params->url_request.url); + return true; + } +}; + +// Test that the partitioned cookie set before the request remains partitioned +// when the site sends a Set-Cookie header with the Partitioned attribute +// and a valid OT token. 
+IN_PROC_BROWSER_TEST_F(SubresourcePartitionedCookiesOriginTrialBrowserTest, + ValidParticipant) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions( + {/*has_ot_token=*/true, /*valid_ot_token=*/true, /*has_set_cookie=*/true, + /*has_partitioned=*/true}, + {TopLevelUrl(), CookieUrl()}); + + WaitForPage(TopLevelUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); +} + +// Test that the partitioned cookie is reverted to unpartitioned if the site +// sends a Set-Cookie with Partitioned but an invalid OT token. +IN_PROC_BROWSER_TEST_F(SubresourcePartitionedCookiesOriginTrialBrowserTest, + InvalidToken) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions( + {/*has_ot_token=*/true, /*valid_ot_token=*/false, /*has_set_cookie=*/true, + /*has_partitioned=*/true}, + {TopLevelUrl(), CookieUrl()}); + + WaitForPage(TopLevelUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_FALSE(cookies[0].IsPartitioned()); +} + +// Test that the partitioned cookie is reverted to unpartitioned if the site +// sends a Set-Cookie with Partitioned but does not send an OT token.
+IN_PROC_BROWSER_TEST_F(SubresourcePartitionedCookiesOriginTrialBrowserTest, + NoToken) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions({/*has_ot_token=*/false, /*valid_ot_token=*/false, + /*has_set_cookie=*/true, + /*has_partitioned=*/true}, + {TopLevelUrl(), CookieUrl()}); + + WaitForPage(TopLevelUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_FALSE(cookies[0].IsPartitioned()); +} + +// The partitioned cookie should stay partitioned since we should not check +// the OT token on responses with no Set-Cookie header. +IN_PROC_BROWSER_TEST_F(SubresourcePartitionedCookiesOriginTrialBrowserTest, + NoSetCookie) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions({/*has_ot_token=*/false, /*valid_ot_token=*/false, + /*has_set_cookie=*/false, + /*has_partitioned=*/true}, + {TopLevelUrl(), CookieUrl()}); + + WaitForPage(TopLevelUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); +} + +// The partitioned cookie should stay partitioned since we should not check +// the OT token on responses with a Set-Cookie header without Partitioned. +IN_PROC_BROWSER_TEST_F(SubresourcePartitionedCookiesOriginTrialBrowserTest, + NoPartitioned) { + SetCookie( + "__Host-foo", "bar", CookieUrl(), + net::CookiePartitionKey::FromURLForTesting(GURL("https://example.com"))); + SetTestOptions({/*has_ot_token=*/false, /*valid_ot_token=*/false, + /*has_set_cookie=*/true, + /*has_partitioned=*/false}, + {TopLevelUrl(), CookieUrl()}); + + WaitForPage(TopLevelUrl()); + + auto cookies = GetCookies(CookieUrl()); + EXPECT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); +} + } // namespace } // namespace content
diff --git a/content/browser/renderer_host/navigation_request.cc b/content/browser/renderer_host/navigation_request.cc index c91c101..875b85e 100644 --- a/content/browser/renderer_host/navigation_request.cc +++ b/content/browser/renderer_host/navigation_request.cc
@@ -29,6 +29,7 @@ #include "base/strings/string_util.h" #include "base/strings/stringprintf.h" #include "base/system/sys_info.h" +#include "base/time/time.h" #include "base/timer/elapsed_timer.h" #include "base/trace_event/trace_conversion_helper.h" #include "build/build_config.h" @@ -109,12 +110,14 @@ #include "content/public/common/url_constants.h" #include "content/public/common/url_utils.h" #include "mojo/public/cpp/system/data_pipe.h" +#include "net/base/features.h" #include "net/base/filename_util.h" #include "net/base/ip_endpoint.h" #include "net/base/load_flags.h" #include "net/base/net_errors.h" #include "net/base/registry_controlled_domains/registry_controlled_domain.h" #include "net/base/url_util.h" +#include "net/cookies/parsed_cookie.h" #include "net/http/http_request_headers.h" #include "net/http/http_status_code.h" #include "net/url_request/redirect_info.h" @@ -147,6 +150,9 @@ #include "third_party/blink/public/common/frame/frame_owner_element_type.h" #include "third_party/blink/public/common/navigation/navigation_params_mojom_traits.h" #include "third_party/blink/public/common/navigation/navigation_policy.h" +#include "third_party/blink/public/common/origin_trials/trial_token.h" +#include "third_party/blink/public/common/origin_trials/trial_token_result.h" +#include "third_party/blink/public/common/origin_trials/trial_token_validator.h" #include "third_party/blink/public/common/permissions_policy/document_policy.h" #include "third_party/blink/public/common/renderer_preferences/renderer_preferences.h" #include "third_party/blink/public/common/security/address_space_feature.h" @@ -942,11 +948,10 @@ } // If the response does not contain an Accept-CH header, then remove the -// Sec-CH-UA-Reduced, Sec-CH-UA-Full, or Sec-CH-Partitioned-Cookies, client -// hint from the Accept-CH cache, if it exists, for the response origin. 
The -// `client_hints` vector also has kUaReduced or kFullUserAgent removed from it -// if the Accept-CH response header doesn't exist, and cookies are -// un-partitioned if that feature is enabled. +// Sec-CH-UA-Reduced or Sec-CH-UA-Full client hint from the Accept-CH cache, if +// it exists, for the response origin. The `client_hints` vector also has +// kUAReduced or kFullUserAgent removed from it if the Accept-CH response header +// doesn't exist. void RemoveOriginTrialHintsFromAcceptCH( const GURL& url, ClientHintsControllerDelegate* delegate, @@ -958,18 +963,16 @@ if (!response || response->parsed_headers->accept_ch) return; - // For Chrome to continue to send Sec-CH-UA-Reduced, Sec-CH-UA-Full, or - // Sec-CH-Partitioned-Cookies, the server must continue replying with: + // For Chrome to continue to send Sec-CH-UA-Reduced or Sec-CH-UA-Full, the + // server must continue replying with: // - a valid Origin Trial token. - // - Accept-CH header with Sec-CH-UA-Reduced, Sec-CH-UA-Full, or - // Sec-CH-Partitioned-Cookies as a value. + // - Accept-CH header with Sec-CH-UA-Reduced or Sec-CH-UA-Full as a value. // // Here, it did not. So it gets removed from the persisted client hints // for the next request.
std::vector<network::mojom::WebClientHintsType> hints_to_remove = { network::mojom::WebClientHintsType::kUAReduced, - network::mojom::WebClientHintsType::kFullUserAgent, - network::mojom::WebClientHintsType::kPartitionedCookies}; + network::mojom::WebClientHintsType::kFullUserAgent}; bool need_update_storage = false; for (const auto& hint : hints_to_remove) { if (base::Contains(client_hints, hint)) { @@ -982,10 +985,46 @@ frame_tree_node->GetParentOrOuterDocument(), delegate, client_hints); } +} - if (auto* cookie_manager = frame_tree_node->current_frame_host() - ->GetStoragePartition() - ->GetCookieManagerForBrowserProcess()) { +bool IsValidPartitionedCookiesOriginTrial( + const GURL& url, + const net::HttpResponseHeaders* response_headers) { + blink::TrialTokenValidator validator; + if (!validator.IsTrialPossibleOnOrigin(url)) + return false; + // Since third-party requests can participate in the CHIPS origin trial and + // typically the Origin-Trial header is reserved for requests from the + // top-level site, we cannot use validator.RequestEnablesFeature here. + url::Origin origin = url::Origin::Create(url); + url::Origin third_party_origins[] = {url::Origin::Create(url)}; + size_t iter = 0; + std::string token; + base::Time now(base::Time::Now()); + while (response_headers->EnumerateHeader(&iter, "Origin-Trial", &token)) { + blink::TrialTokenResult result = + validator.ValidateToken(token, origin, third_party_origins, now); + if (result.Status() == blink::OriginTrialTokenStatus::kSuccess) { + if (result.ParsedToken()->feature_name() == "PartitionedCookies") { + return true; + } + } + } + return false; +} + +// For the partitioned cookies OT, we check if the response has a Set-Cookie +// header with a partitioned cookie. If it does and the response carries no +// valid OT token, we convert the URL's partitioned cookies to unpartitioned.
+void CheckPartitionedCookiesOriginTrial( + const network::mojom::URLResponseHead* response, + const GURL& url, + network::mojom::CookieManager* cookie_manager) { + if (!base::FeatureList::IsEnabled(net::features::kPartitionedCookies) || + !response || !cookie_manager || !response->has_partitioned_cookie) { + return; + } + if (!IsValidPartitionedCookiesOriginTrial(url, response->headers.get())) { cookie_manager->ConvertPartitionedCookiesToUnpartitioned(url); } } @@ -4715,6 +4754,11 @@ commit_params_->enabled_client_hints, frame_tree_node_); } + CheckPartitionedCookiesOriginTrial(response(), common_params_->url, + frame_tree_node_->current_frame_host() + ->GetStoragePartition() + ->GetCookieManagerForBrowserProcess()); + // Generate a UKM source and track it on NavigationRequest. This will be // passed down to the blink::Document to be created, if any, and used for UKM // source creation when navigation has successfully committed.
diff --git a/content/browser/renderer_host/render_widget_host_view_aura.cc b/content/browser/renderer_host/render_widget_host_view_aura.cc index 224a021..1da8f98 100644 --- a/content/browser/renderer_host/render_widget_host_view_aura.cc +++ b/content/browser/renderer_host/render_widget_host_view_aura.cc
@@ -1598,7 +1598,7 @@ #endif -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) gfx::Range RenderWidgetHostViewAura::GetAutocorrectRange() const { if (!text_input_manager_ || !text_input_manager_->GetActiveWidget()) return gfx::Range(); @@ -1672,7 +1672,7 @@ } absl::optional<ui::GrammarFragment> -RenderWidgetHostViewAura::GetGrammarFragmentAtCursor() { +RenderWidgetHostViewAura::GetGrammarFragmentAtCursor() const { if (!text_input_manager_ || !text_input_manager_->GetActiveWidget()) return absl::nullopt; gfx::Range selection_range;
diff --git a/content/browser/renderer_host/render_widget_host_view_aura.h b/content/browser/renderer_host/render_widget_host_view_aura.h index 80faaf0..1b18939 100644 --- a/content/browser/renderer_host/render_widget_host_view_aura.h +++ b/content/browser/renderer_host/render_widget_host_view_aura.h
@@ -250,11 +250,12 @@ const std::vector<ui::ImeTextSpan>& ui_ime_text_spans) override; #endif -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) gfx::Range GetAutocorrectRange() const override; gfx::Rect GetAutocorrectCharacterBounds() const override; bool SetAutocorrectRange(const gfx::Range& range) override; - absl::optional<ui::GrammarFragment> GetGrammarFragmentAtCursor() override; + absl::optional<ui::GrammarFragment> GetGrammarFragmentAtCursor() + const override; bool ClearGrammarFragments(const gfx::Range& range) override; bool AddGrammarFragments( const std::vector<ui::GrammarFragment>& fragments) override;
diff --git a/content/child/runtime_features.cc b/content/child/runtime_features.cc index 44b16d2..a5a33ea4 100644 --- a/content/child/runtime_features.cc +++ b/content/child/runtime_features.cc
@@ -406,8 +406,6 @@ blink::features::kClientHintThirdPartyDelegation}, {"UserAgentReduction", blink::features::kReduceUserAgent}, {"UserAgentFull", blink::features::kFullUserAgent}, - {"ClientHintPartitiondCookies", - blink::features::kClientHintsPartitionedCookies}, {"WindowPlacement", blink::features::kWindowPlacement}, {"WindowPlacementFullscreenOnScreensChange", blink::features::kWindowPlacementFullscreenOnScreensChange},
diff --git a/content/public/test/url_loader_interceptor.cc b/content/public/test/url_loader_interceptor.cc index 5f951b0..68ed149 100644 --- a/content/public/test/url_loader_interceptor.cc +++ b/content/public/test/url_loader_interceptor.cc
@@ -30,6 +30,7 @@ #include "content/public/test/mock_render_process_host.h" #include "mojo/public/cpp/bindings/receiver.h" #include "mojo/public/cpp/bindings/receiver_set.h" +#include "net/cookies/parsed_cookie.h" #include "net/http/http_util.h" #include "net/test/embedded_test_server/request_handler_util.h" #include "services/network/public/cpp/features.h" @@ -585,6 +586,14 @@ network::PopulateParsedHeaders(response->headers.get(), *url); } response->ssl_info = std::move(ssl_info); + size_t iter = 0; + std::string cookie_line; + while (info.headers->EnumerateHeader(&iter, "Set-Cookie", &cookie_line)) { + if (net::ParsedCookie(cookie_line).IsPartitioned()) { + response->has_partitioned_cookie = true; + break; + } + } mojo::ScopedDataPipeProducerHandle producer_handle; mojo::ScopedDataPipeConsumerHandle consumer_handle;
diff --git a/content/renderer/render_frame_impl.cc b/content/renderer/render_frame_impl.cc index a6b89dd3..1138a40 100644 --- a/content/renderer/render_frame_impl.cc +++ b/content/renderer/render_frame_impl.cc
@@ -5000,10 +5000,11 @@ } } - // Ensure we will propagate frame intersections when the main frame commits - // even if the intersection does not change across navigations. + // Ensure we will propagate the main frame and viewport rect when the main + // frame commits even if the rect does not change across navigations. if (IsMainFrame()) { main_frame_intersection_rect_.reset(); + main_frame_viewport_rect_.reset(); } }
diff --git a/content/renderer/render_frame_impl_browsertest.cc b/content/renderer/render_frame_impl_browsertest.cc index c86da20..1125932 100644 --- a/content/renderer/render_frame_impl_browsertest.cc +++ b/content/renderer/render_frame_impl_browsertest.cc
@@ -219,13 +219,19 @@ const gfx::Rect& intersection_rect) override { last_intersection_rect_ = intersection_rect; } + void OnMainFrameViewportRectangleChanged( + const gfx::Rect& viewport_rect) override { + last_viewport_rect_ = viewport_rect; + } - bool visible() { return visible_; } - gfx::Rect last_intersection_rect() { return last_intersection_rect_; } + bool visible() const { return visible_; } + gfx::Rect last_intersection_rect() const { return last_intersection_rect_; } + gfx::Rect last_viewport_rect() const { return last_viewport_rect_; } private: bool visible_; gfx::Rect last_intersection_rect_; + gfx::Rect last_viewport_rect_; }; // Verify that a frame with a RenderFrameProxy as a parent has its own @@ -460,6 +466,20 @@ EXPECT_EQ(observer.last_intersection_rect(), mainframe_intersection); } +TEST_F(RenderFrameImplTest, MainFrameViewportRectRecorded) { + RenderFrameTestObserver observer(GetMainRenderFrame()); + gfx::Rect mainframe_viewport(0, 0, 200, 140); + GetMainRenderFrame()->OnMainFrameViewportRectangleChanged(mainframe_viewport); + EXPECT_EQ(observer.last_viewport_rect(), mainframe_viewport); + + // After a navigation, the notification of `mainframe_viewport` should be + // propagated to `RenderFrameTestObserver` again for the new document. + LoadHTML(kParentFrameHTML); + RenderFrameTestObserver observer2(GetMainRenderFrame()); + GetMainRenderFrame()->OnMainFrameViewportRectangleChanged(mainframe_viewport); + EXPECT_EQ(observer2.last_viewport_rect(), mainframe_viewport); +} + // Used to annotate the source of an interface request. struct SourceAnnotation { // The URL of the active document in the frame, at the time the interface was
diff --git a/content/test/BUILD.gn b/content/test/BUILD.gn index daeebae6..4c98b05 100644 --- a/content/test/BUILD.gn +++ b/content/test/BUILD.gn
@@ -3067,8 +3067,11 @@ data = [ "//content/test/gpu/run_pytype.py", + "//content/test/gpu/validate_tag_consistency.py", "//build/util/lib/results/", + "//content/test/gpu/flake_suppressor/", + "//testing/unexpected_passes_common/", ] data_deps = [
diff --git a/content/test/gpu/gpu_tests/test_expectations/pixel_expectations.txt b/content/test/gpu/gpu_tests/test_expectations/pixel_expectations.txt index 9bf5d44..f8f6df28 100644 --- a/content/test/gpu/gpu_tests/test_expectations/pixel_expectations.txt +++ b/content/test/gpu/gpu_tests/test_expectations/pixel_expectations.txt
@@ -408,6 +408,8 @@ # Vulkan Swiftshader WebGPU unaccelerated OffscreenCanvas missing source image crbug.com/1307787 [ linux skia-renderer-vulkan ] Pixel_VulkanSwiftShader_WebGPUImportVideoFrameUnacceleratedOffscreenCanvas [ Failure ] +crbug.com/1203317 [ android android-shield-android-tv no-passthrough ] Pixel_OffscreenCanvasWebGLPaintAfterResize [ RetryOnFailure ] + ####################################################################### # Automated Entries After This Point - Do Not Manually Add Below Here # #######################################################################
diff --git a/content/test/gpu/run_pytype.py b/content/test/gpu/run_pytype.py index 5898544e..522fa96 100755 --- a/content/test/gpu/run_pytype.py +++ b/content/test/gpu/run_pytype.py
@@ -138,6 +138,14 @@ args.output_file, sink_client) sys.exit(0) + # Strangely, pytype won't complain if you tell it to analyze a directory that + # doesn't exist, which could potentially lead to code not being analyzed if + # it's added here but not added to the isolate. So, ensure that everything we + # expect to analyze actually exists. + for f in FILES_AND_DIRECTORIES_TO_CHECK: + if not os.path.exists(os.path.join(GPU_DIR, f)): + raise RuntimeError('Requested file or directory %s does not exist.' % f) + # pytype looks for a 'python' or 'python3' executable in PATH, so make sure # that the Python 3 executable from vpython is in the path. executable_dir = os.path.dirname(sys.executable)
diff --git a/fuchsia/base/BUILD.gn b/fuchsia/base/BUILD.gn index 3d6ec430..ae306d8 100644 --- a/fuchsia/base/BUILD.gn +++ b/fuchsia/base/BUILD.gn
@@ -88,26 +88,6 @@ ] } -source_set("legacymetrics") { - visibility = [ - ":cr_fuchsia_base_unittests__exec", - "//chromecast/internal/*", - "//fuchsia/engine/*", - ] - sources = [ - "legacymetrics_client.cc", - "legacymetrics_client.h", - "legacymetrics_histogram_flattener.cc", - "legacymetrics_histogram_flattener.h", - "legacymetrics_user_event_recorder.cc", - "legacymetrics_user_event_recorder.h", - ] - public = [ "legacymetrics_client.h" ] - deps = [ "//base" ] - public_deps = [ "//third_party/fuchsia-sdk/sdk/fidl/fuchsia.legacymetrics" ] - friend = [ ":*" ] -} - static_library("run_all_integration_tests") { testonly = true visibility = [ @@ -124,12 +104,8 @@ "agent_impl_unittest.cc", "config_reader_unittest.cc", "inspect_unittest.cc", - "legacymetrics_client_unittest.cc", - "legacymetrics_histogram_flattener_unittest.cc", - "legacymetrics_user_event_recorder_unittest.cc", ] deps = [ - ":legacymetrics", ":modular", "//base", "//base:testfidl",
diff --git a/fuchsia/base/legacymetrics_histogram_flattener.h b/fuchsia/base/legacymetrics_histogram_flattener.h deleted file mode 100644 index d4b147a..0000000 --- a/fuchsia/base/legacymetrics_histogram_flattener.h +++ /dev/null
@@ -1,20 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef FUCHSIA_BASE_LEGACYMETRICS_HISTOGRAM_FLATTENER_H_ -#define FUCHSIA_BASE_LEGACYMETRICS_HISTOGRAM_FLATTENER_H_ - -#include <fuchsia/legacymetrics/cpp/fidl.h> -#include <vector> - -#include "base/metrics/histogram_flattener.h" -#include "base/metrics/histogram_snapshot_manager.h" - -namespace cr_fuchsia { - -std::vector<fuchsia::legacymetrics::Histogram> GetLegacyMetricsDeltas(); - -} // namespace cr_fuchsia - -#endif // FUCHSIA_BASE_LEGACYMETRICS_HISTOGRAM_FLATTENER_H_
diff --git a/fuchsia/engine/BUILD.gn b/fuchsia/engine/BUILD.gn index e26e41cd..f44bcbaa 100644 --- a/fuchsia/engine/BUILD.gn +++ b/fuchsia/engine/BUILD.gn
@@ -96,6 +96,7 @@ "//components/embedder_support/origin_trials", "//components/favicon/content", "//components/favicon/core", + "//components/fuchsia_legacymetrics", "//components/keyed_service/content", "//components/media_control/browser", "//components/media_control/renderer", @@ -118,7 +119,6 @@ "//content/public/common", "//content/public/renderer", "//fuchsia/base", - "//fuchsia/base:legacymetrics", "//fuchsia/base:message_port", "//fuchsia/base:modular", "//fuchsia/engine/mojom",
diff --git a/fuchsia/engine/DEPS b/fuchsia/engine/DEPS index a5b1c33..388bc50 100644 --- a/fuchsia/engine/DEPS +++ b/fuchsia/engine/DEPS
@@ -31,4 +31,7 @@ "context_provider_impl_unittest\.cc": [ "+services/network/public/cpp/network_switches.h" ], + "web_engine_browser_main_parts\.cc": [ + "+components/fuchsia_legacymetrics/legacymetrics_client.h" + ] }
diff --git a/fuchsia/engine/browser/web_engine_browser_main_parts.cc b/fuchsia/engine/browser/web_engine_browser_main_parts.cc index 38aa87c..f0e5b63 100644 --- a/fuchsia/engine/browser/web_engine_browser_main_parts.cc +++ b/fuchsia/engine/browser/web_engine_browser_main_parts.cc
@@ -30,6 +30,7 @@ #include "base/threading/thread_restrictions.h" #include "base/threading/thread_task_runner_handle.h" #include "build/build_config.h" +#include "components/fuchsia_legacymetrics/legacymetrics_client.h" #include "content/public/browser/content_browser_client.h" #include "content/public/browser/gpu_data_manager.h" #include "content/public/browser/histogram_fetcher.h" @@ -40,7 +41,6 @@ #include "content/public/common/main_function_params.h" #include "content/public/common/result_codes.h" #include "fuchsia/base/inspect.h" -#include "fuchsia/base/legacymetrics_client.h" #include "fuchsia/engine/browser/context_impl.h" #include "fuchsia/engine/browser/web_engine_browser_context.h" #include "fuchsia/engine/browser/web_engine_devtools_controller.h" @@ -219,7 +219,7 @@ if (command_line->HasSwitch(switches::kUseLegacyMetricsService)) { legacy_metrics_client_ = - std::make_unique<cr_fuchsia::LegacyMetricsClient>(); + std::make_unique<fuchsia_legacymetrics::LegacyMetricsClient>(); // Add a hook to asynchronously pull metrics from child processes just prior // to uploading.
diff --git a/fuchsia/engine/browser/web_engine_browser_main_parts.h b/fuchsia/engine/browser/web_engine_browser_main_parts.h index 91bd530..681b71a 100644 --- a/fuchsia/engine/browser/web_engine_browser_main_parts.h +++ b/fuchsia/engine/browser/web_engine_browser_main_parts.h
@@ -29,7 +29,7 @@ class ContentBrowserClient; } -namespace cr_fuchsia { +namespace fuchsia_legacymetrics { class LegacyMetricsClient; } @@ -135,7 +135,8 @@ frame_host_bindings_; std::unique_ptr<WebEngineDevToolsController> devtools_controller_; - std::unique_ptr<cr_fuchsia::LegacyMetricsClient> legacy_metrics_client_; + std::unique_ptr<fuchsia_legacymetrics::LegacyMetricsClient> + legacy_metrics_client_; std::unique_ptr<media::FuchsiaCdmManager> cdm_manager_; // Used to respond to changes to the system's current locale.
diff --git a/gpu/vulkan/BUILD.gn b/gpu/vulkan/BUILD.gn index 7535d534..58092356 100644 --- a/gpu/vulkan/BUILD.gn +++ b/gpu/vulkan/BUILD.gn
@@ -109,6 +109,10 @@ defines = [ "IS_VULKAN_IMPL" ] + if (ozone_platform == "x11") { + defines += [ "OZONE_PLATFORM_IS_X11" ] + } + deps = [ ":buildflags", "//base",
diff --git a/gpu/vulkan/vulkan_util.cc b/gpu/vulkan/vulkan_util.cc index 6891236c..cdb228d 100644 --- a/gpu/vulkan/vulkan_util.cc +++ b/gpu/vulkan/vulkan_util.cc
@@ -198,9 +198,14 @@ return false; } } -#endif // !BUILDFLAG(IS_ANDROID) -#if BUILDFLAG(IS_ANDROID) +#if BUILDFLAG(IS_LINUX) && !defined(OZONE_PLATFORM_IS_X11) + // Vulkan is only supported with X11 on Linux for now. + return false; +#else + return true; +#endif +#else // BUILDFLAG(IS_ANDROID) if (vulkan_info.physical_devices.empty()) return false; @@ -255,9 +260,9 @@ // Imagination GPUs. if (device_info.properties.vendorID == kVendorImagination) return false; -#endif // BUILDFLAG(IS_ANDROID) return true; +#endif // BUILDFLAG(IS_ANDROID) } VkImageLayout GLImageLayoutToVkImageLayout(uint32_t layout) {
diff --git a/infra/config/generated/builders/ci/GPU Mac Builder/properties.json b/infra/config/generated/builders/ci/GPU Mac Builder/properties.json index a0ae05c..24aab43 100644 --- a/infra/config/generated/builders/ci/GPU Mac Builder/properties.json +++ b/infra/config/generated/builders/ci/GPU Mac Builder/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/ci/Mac Builder/properties.json b/infra/config/generated/builders/ci/Mac Builder/properties.json index af5cfa92..f3cd705 100644 --- a/infra/config/generated/builders/ci/Mac Builder/properties.json +++ b/infra/config/generated/builders/ci/Mac Builder/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git "a/infra/config/generated/builders/ci/Mac Release \050Intel\051/properties.json" "b/infra/config/generated/builders/ci/Mac Release \050Intel\051/properties.json" index 55ccc4a..f67149d 100644 --- "a/infra/config/generated/builders/ci/Mac Release \050Intel\051/properties.json" +++ "b/infra/config/generated/builders/ci/Mac Release \050Intel\051/properties.json"
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git "a/infra/config/generated/builders/ci/Mac Retina Release \050AMD\051/properties.json" "b/infra/config/generated/builders/ci/Mac Retina Release \050AMD\051/properties.json" index 1328272..a89e92f0 100644 --- "a/infra/config/generated/builders/ci/Mac Retina Release \050AMD\051/properties.json" +++ "b/infra/config/generated/builders/ci/Mac Retina Release \050AMD\051/properties.json"
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/ci/Mac10.13 Tests/properties.json b/infra/config/generated/builders/ci/Mac10.13 Tests/properties.json index 0df31684..3df545f5 100644 --- a/infra/config/generated/builders/ci/Mac10.13 Tests/properties.json +++ b/infra/config/generated/builders/ci/Mac10.13 Tests/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/ci/Mac10.14 Tests/properties.json b/infra/config/generated/builders/ci/Mac10.14 Tests/properties.json index 0f088bc..20b7487 100644 --- a/infra/config/generated/builders/ci/Mac10.14 Tests/properties.json +++ b/infra/config/generated/builders/ci/Mac10.14 Tests/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/ci/Mac10.15 Tests/properties.json b/infra/config/generated/builders/ci/Mac10.15 Tests/properties.json index f3cca73..eec3dcc 100644 --- a/infra/config/generated/builders/ci/Mac10.15 Tests/properties.json +++ b/infra/config/generated/builders/ci/Mac10.15 Tests/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/ci/Mac11 Tests/properties.json b/infra/config/generated/builders/ci/Mac11 Tests/properties.json index 3ef8fb3..6f0d6ca2 100644 --- a/infra/config/generated/builders/ci/Mac11 Tests/properties.json +++ b/infra/config/generated/builders/ci/Mac11 Tests/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/ci/mac-official/properties.json b/infra/config/generated/builders/ci/mac-official/properties.json index 87b58e30..9acf9519c 100644 --- a/infra/config/generated/builders/ci/mac-official/properties.json +++ b/infra/config/generated/builders/ci/mac-official/properties.json
@@ -23,7 +23,7 @@ "apply_configs": [ "checkout_pgo_profiles" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } }
diff --git a/infra/config/generated/builders/ci/mac-osxbeta-rel/properties.json b/infra/config/generated/builders/ci/mac-osxbeta-rel/properties.json index eb5eae8..48b8012 100644 --- a/infra/config/generated/builders/ci/mac-osxbeta-rel/properties.json +++ b/infra/config/generated/builders/ci/mac-osxbeta-rel/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/try/mac-inverse-fieldtrials-fyi-rel/properties.json b/infra/config/generated/builders/try/mac-inverse-fieldtrials-fyi-rel/properties.json index 0f74b13..f48e476 100644 --- a/infra/config/generated/builders/try/mac-inverse-fieldtrials-fyi-rel/properties.json +++ b/infra/config/generated/builders/try/mac-inverse-fieldtrials-fyi-rel/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } }, @@ -55,7 +55,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/try/mac-official/properties.json b/infra/config/generated/builders/try/mac-official/properties.json index 56df05c..046f5307 100644 --- a/infra/config/generated/builders/try/mac-official/properties.json +++ b/infra/config/generated/builders/try/mac-official/properties.json
@@ -23,7 +23,7 @@ "apply_configs": [ "checkout_pgo_profiles" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } }
diff --git a/infra/config/generated/builders/try/mac-osxbeta-rel/properties.json b/infra/config/generated/builders/try/mac-osxbeta-rel/properties.json index 94980fc..415d42a 100644 --- a/infra/config/generated/builders/try/mac-osxbeta-rel/properties.json +++ b/infra/config/generated/builders/try/mac-osxbeta-rel/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/try/mac-rel-compilator/properties.json b/infra/config/generated/builders/try/mac-rel-compilator/properties.json index 69093c3..57e8184 100644 --- a/infra/config/generated/builders/try/mac-rel-compilator/properties.json +++ b/infra/config/generated/builders/try/mac-rel-compilator/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } }, @@ -55,7 +55,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/try/mac-rel/properties.json b/infra/config/generated/builders/try/mac-rel/properties.json index 16056b0e..b3651b5 100644 --- a/infra/config/generated/builders/try/mac-rel/properties.json +++ b/infra/config/generated/builders/try/mac-rel/properties.json
@@ -31,7 +31,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } }, @@ -59,7 +59,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/try/mac_chromium_10.13_rel_ng/properties.json b/infra/config/generated/builders/try/mac_chromium_10.13_rel_ng/properties.json index cac2dce..a910f52 100644 --- a/infra/config/generated/builders/try/mac_chromium_10.13_rel_ng/properties.json +++ b/infra/config/generated/builders/try/mac_chromium_10.13_rel_ng/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/try/mac_chromium_10.14_rel_ng/properties.json b/infra/config/generated/builders/try/mac_chromium_10.14_rel_ng/properties.json index 199778a..72afc43 100644 --- a/infra/config/generated/builders/try/mac_chromium_10.14_rel_ng/properties.json +++ b/infra/config/generated/builders/try/mac_chromium_10.14_rel_ng/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/try/mac_chromium_10.15_rel_ng/properties.json b/infra/config/generated/builders/try/mac_chromium_10.15_rel_ng/properties.json index cacbe0c..c2291b2 100644 --- a/infra/config/generated/builders/try/mac_chromium_10.15_rel_ng/properties.json +++ b/infra/config/generated/builders/try/mac_chromium_10.15_rel_ng/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/try/mac_chromium_11.0_rel_ng/properties.json b/infra/config/generated/builders/try/mac_chromium_11.0_rel_ng/properties.json index f68b9f50..df75d681 100644 --- a/infra/config/generated/builders/try/mac_chromium_11.0_rel_ng/properties.json +++ b/infra/config/generated/builders/try/mac_chromium_11.0_rel_ng/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/generated/builders/try/mac_chromium_compile_rel_ng/properties.json b/infra/config/generated/builders/try/mac_chromium_compile_rel_ng/properties.json index 622df7f6..a3ded6e 100644 --- a/infra/config/generated/builders/try/mac_chromium_compile_rel_ng/properties.json +++ b/infra/config/generated/builders/try/mac_chromium_compile_rel_ng/properties.json
@@ -27,7 +27,7 @@ "apply_configs": [ "use_clang_coverage" ], - "config": "chromium_no_telemetry_dependencies" + "config": "chromium" } } },
diff --git a/infra/config/subprojects/chromium/ci/chromium.gpu.star b/infra/config/subprojects/chromium/ci/chromium.gpu.star index cc984246..31de415 100644 --- a/infra/config/subprojects/chromium/ci/chromium.gpu.star +++ b/infra/config/subprojects/chromium/ci/chromium.gpu.star
@@ -105,7 +105,7 @@ branch_selector = branches.DESKTOP_EXTENDED_STABLE_MILESTONE, builder_spec = builder_config.builder_spec( gclient_config = builder_config.gclient_config( - config = "chromium_no_telemetry_dependencies", + config = "chromium", apply_configs = [ "use_clang_coverage", ],
diff --git a/infra/config/subprojects/chromium/ci/chromium.mac.star b/infra/config/subprojects/chromium/ci/chromium.mac.star index 27eacd3..3832e933 100644 --- a/infra/config/subprojects/chromium/ci/chromium.mac.star +++ b/infra/config/subprojects/chromium/ci/chromium.mac.star
@@ -56,7 +56,7 @@ branch_selector = branches.DESKTOP_EXTENDED_STABLE_MILESTONE, builder_spec = builder_config.builder_spec( gclient_config = builder_config.gclient_config( - config = "chromium_no_telemetry_dependencies", + config = "chromium", apply_configs = [ "use_clang_coverage", ],
diff --git a/infra/config/subprojects/chromium/ci/chromium.star b/infra/config/subprojects/chromium/ci/chromium.star index aa0dc72..15759d64 100644 --- a/infra/config/subprojects/chromium/ci/chromium.star +++ b/infra/config/subprojects/chromium/ci/chromium.star
@@ -473,7 +473,7 @@ branch_selector = branches.DESKTOP_EXTENDED_STABLE_MILESTONE, builder_spec = builder_config.builder_spec( gclient_config = builder_config.gclient_config( - config = "chromium_no_telemetry_dependencies", + config = "chromium", apply_configs = [ "checkout_pgo_profiles", ],
diff --git a/ios/chrome/app/strings/ios_strings.grd b/ios/chrome/app/strings/ios_strings.grd index 5b4b6f9..1e34f6d 100644 --- a/ios/chrome/app/strings/ios_strings.grd +++ b/ios/chrome/app/strings/ios_strings.grd
@@ -3450,10 +3450,10 @@ <message name="IDS_IOS_SHARED_HIGHLIGHT_REMOVE" desc="An option in a menu which appears after a user taps on a highlighted passage in a web page. Selecting this option causes the passage to no longer be highlighted. Title-cased. [Length: 25em]"> Remove </message> - <message name="IDS_IOS_IPH_BUBBLE_SNOOZE" desc="Label of a button that appears in an in-product help Tip. Pressing this button will dismiss the Tip, the user will be reminded of the same Tip later on. Title-cased."> + <message name="IDS_IOS_IPH_BUBBLE_SNOOZE" desc="Label of a button that appears in an in-product help Tip. Pressing this button will dismiss the Tip, the user will be reminded of the same Tip later on. Used as a button title in an info bubble." meaning="Used as a button title in an info bubble."> Remind Me Later </message> - <message name="IDS_IOS_PASSWORD_SUGGESTIONS_TIP_TITLE" desc="The title of the tip explaining to the user they can tap on Autofill password suggestions"> + <message name="IDS_IOS_PASSWORD_SUGGESTIONS_TIP_TITLE" desc="The title of the tip explaining to the user they can tap on Autofill password suggestions. Used as a title in an info bubble." meaning="Used as a title in an info bubble."> Autofill Passwords </message> </messages>
diff --git a/ios/chrome/browser/feature_engagement/BUILD.gn b/ios/chrome/browser/feature_engagement/BUILD.gn index d933f91..9dbfa46b 100644 --- a/ios/chrome/browser/feature_engagement/BUILD.gn +++ b/ios/chrome/browser/feature_engagement/BUILD.gn
@@ -77,7 +77,9 @@ ":eg_test_support+eg2", "//base", "//components/feature_engagement/public", - "//ios/chrome/app/strings:ios_strings_grit", + "//ios/chrome/app/strings", + "//ios/chrome/browser/passwords:eg_test_support+eg2", + "//ios/chrome/browser/ui/bubble:features", "//ios/chrome/browser/ui/popup_menu:constants", "//ios/chrome/browser/ui/table_view:constants", "//ios/chrome/test/earl_grey:eg_test_support+eg2",
diff --git a/ios/chrome/browser/feature_engagement/feature_engagement_app_interface.h b/ios/chrome/browser/feature_engagement/feature_engagement_app_interface.h index 33600d9..24304c61 100644 --- a/ios/chrome/browser/feature_engagement/feature_engagement_app_interface.h +++ b/ios/chrome/browser/feature_engagement/feature_engagement_app_interface.h
@@ -50,6 +50,11 @@ // FeatureEngagementTracker failed to load. + (BOOL)enableDefaultSiteViewTipTriggering [[nodiscard]]; +// Enables the Password Suggestions tip to be triggered. The tip is triggered +// only once the first time Autofill password suggestions are shown. Returns NO +// if FeatureEngagementTracker failed to load. ++ (BOOL)enablePasswordSuggestionsTipTriggering [[nodiscard]]; + // Starts manual page translation. + (void)showTranslate;
diff --git a/ios/chrome/browser/feature_engagement/feature_engagement_app_interface.mm b/ios/chrome/browser/feature_engagement/feature_engagement_app_interface.mm index b7d905d..19ea193 100644 --- a/ios/chrome/browser/feature_engagement/feature_engagement_app_interface.mm +++ b/ios/chrome/browser/feature_engagement/feature_engagement_app_interface.mm
@@ -209,6 +209,27 @@ return LoadFeatureEngagementTracker(); } ++ (BOOL)enablePasswordSuggestionsTipTriggering { + std::map<std::string, std::string> password_suggestions_tip_params; + + password_suggestions_tip_params["availability"] = "any"; + password_suggestions_tip_params["session_rate"] = "any"; + password_suggestions_tip_params["event_used"] = + "name:password_suggestions_shown;comparator:==0;window:90;" + "storage:360"; + password_suggestions_tip_params["event_trigger"] = + "name:password_suggestions_iph_triggered;comparator:==0;window:1825;" + "storage:1825"; + + ScopedFeatureListHolder::GetInstance() + ->CreateList() + .InitAndEnableFeatureWithParameters( + feature_engagement::kIPHPasswordSuggestionsFeature, + password_suggestions_tip_params); + + return LoadFeatureEngagementTracker(); +} + + (void)showTranslate { [chrome_test_util::HandlerForActiveBrowser() showTranslate]; }
diff --git a/ios/chrome/browser/feature_engagement/feature_engagement_egtest.mm b/ios/chrome/browser/feature_engagement/feature_engagement_egtest.mm index 399f5e9..ed5eb16 100644 --- a/ios/chrome/browser/feature_engagement/feature_engagement_egtest.mm +++ b/ios/chrome/browser/feature_engagement/feature_engagement_egtest.mm
@@ -7,14 +7,21 @@ #include "base/strings/sys_string_conversions.h" #import "base/test/ios/wait_util.h" #import "ios/chrome/browser/feature_engagement/feature_engagement_app_interface.h" +#import "ios/chrome/browser/passwords/password_manager_app_interface.h" +#import "ios/chrome/browser/ui/bubble/bubble_features.h" #import "ios/chrome/browser/ui/popup_menu/popup_menu_constants.h" #import "ios/chrome/browser/ui/table_view/table_view_navigation_controller_constants.h" +#include "ios/chrome/grit/ios_chromium_strings.h" #include "ios/chrome/grit/ios_strings.h" +#import "ios/chrome/test/earl_grey/chrome_actions.h" #import "ios/chrome/test/earl_grey/chrome_earl_grey.h" #import "ios/chrome/test/earl_grey/chrome_earl_grey_ui.h" #import "ios/chrome/test/earl_grey/chrome_matchers.h" #import "ios/chrome/test/earl_grey/chrome_test_case.h" +#include "ios/testing/earl_grey/app_launch_configuration.h" +#import "ios/testing/earl_grey/app_launch_manager.h" #import "ios/testing/earl_grey/earl_grey_test.h" +#import "net/base/mac/url_conversions.h" #include "net/test/embedded_test_server/embedded_test_server.h" #include "net/test/embedded_test_server/http_response.h" #include "net/test/embedded_test_server/request_handler_util.h" @@ -42,6 +49,12 @@ // URL path for a page with text in French. const char kFrenchPageURLPath[] = "/french"; +// URL path for a page with password field form. +constexpr char kPasswordForm[] = "/username_password_field_form.html"; + +// Element ID for the username field in the password form. +constexpr char kPasswordFormUsername[] = "username"; + // Matcher for the Reading List Text Badge. id<GREYMatcher> ReadingListTextBadge() { NSString* new_overflow_menu_accessibility_id = @@ -93,6 +106,12 @@ l10n_util::GetNSStringWithFixup(IDS_IOS_DEFAULT_PAGE_MODE_TIP)); } +// Matcher for the PasswordSuggestions tip. 
+id<GREYMatcher> PasswordSuggestionsTip() { + return grey_accessibilityLabel( + l10n_util::GetNSStringWithFixup(IDS_IOS_PASSWORD_SUGGESTIONS_TIP)); +} + // Opens the TabGrid and then opens a new tab. void OpenTabGridAndOpenTab() { [[EarlGrey selectElementWithMatcher:chrome_test_util::ShowTabsButton()] @@ -147,8 +166,16 @@ @implementation FeatureEngagementTestCase +- (AppLaunchConfiguration)appConfigurationForTestCase { + AppLaunchConfiguration config = [super appConfigurationForTestCase]; + // Flag to enable password suggestion highlight and tip. + config.features_enabled.push_back(kBubbleRichIPH); + return config; +} + - (void)tearDown { [FeatureEngagementAppInterface reset]; + [PasswordManagerAppInterface clearCredentials]; [super tearDown]; } @@ -509,4 +536,53 @@ assertWithMatcher:grey_nil()]; } +// Verifies that the password suggestion tip is displayed only the first time +// password suggestions are shown. +- (void)testPasswordSuggestionsTip { + GREYAssert( + [FeatureEngagementAppInterface enablePasswordSuggestionsTipTriggering], + @"Feature Engagement tracker did not load"); + self.testServer->AddDefaultHandlers(); + GREYAssertTrue(self.testServer->Start(), @"Test server failed to start"); + + // Save the password. + NSURL* URL = net::NSURLWithGURL(self.testServer->GetURL(kPasswordForm)); + [PasswordManagerAppInterface storeCredentialWithUsername:@"EgUsername" + password:@"EgPassword" + URL:URL]; + int credentialsCount = [PasswordManagerAppInterface storedCredentialsCount]; + GREYAssertEqual(1, credentialsCount, @"Wrong number of stored credentials."); + + // Reopen the page, and focus the login text fields. This should trigger the + // tip. 
+ [ChromeEarlGreyUI openNewTab]; + [ChromeEarlGrey loadURL:self.testServer->GetURL(kPasswordForm)]; + [[EarlGrey selectElementWithMatcher:chrome_test_util::WebViewMatcher()] + performAction:chrome_test_util::TapWebElementWithId( + kPasswordFormUsername)]; + [ChromeEarlGrey + waitForSufficientlyVisibleElementWithMatcher:PasswordSuggestionsTip()]; + + // Dismiss the keyboard. + NSError* error = nil; + GREYAssert([EarlGrey dismissKeyboardWithError:&error] && error == nil, + @"Cannot dismiss the keyboard"); + + // Second time, the tip should no longer trigger. + [ChromeEarlGreyUI openNewTab]; + [ChromeEarlGrey loadURL:self.testServer->GetURL(kPasswordForm)]; + [[EarlGrey selectElementWithMatcher:chrome_test_util::WebViewMatcher()] + performAction:chrome_test_util::TapWebElementWithId( + kPasswordFormUsername)]; + ConditionBlock condition = ^{ + NSError* error = nil; + [[EarlGrey selectElementWithMatcher:PasswordSuggestionsTip()] + assertWithMatcher:grey_sufficientlyVisible() + error:&error]; + return error == nil; + }; + GREYAssert(!WaitUntilConditionOrTimeout(kWaitForUIElementTimeout, condition), + @"The password suggestion tip shouldn't appear"); +} + @end
diff --git a/ios/chrome/browser/translate/BUILD.gn b/ios/chrome/browser/translate/BUILD.gn index 84df6a4..6495676 100644 --- a/ios/chrome/browser/translate/BUILD.gn +++ b/ios/chrome/browser/translate/BUILD.gn
@@ -75,6 +75,7 @@ configs += [ "//build/config/compiler:enable_arc" ] testonly = true sources = [ + "chrome_ios_translate_client_unittest.mm", "language_detection_javascript_unittest.mm", "translate_service_ios_unittest.cc", ] @@ -84,10 +85,15 @@ "//base", "//base/test:test_support", "//components/language/core/browser", + "//components/language/ios/browser", "//components/translate/core/browser:test_support", + "//components/translate/core/common", "//components/translate/ios/browser", "//ios/chrome/browser", "//ios/chrome/browser/browser_state:test_support", + "//ios/chrome/browser/infobars", + "//ios/chrome/browser/language", + "//ios/chrome/browser/optimization_guide", "//ios/chrome/browser/web:web_internal", "//ios/chrome/common:string_util", "//ios/web/public",
diff --git a/ios/chrome/browser/translate/chrome_ios_translate_client_unittest.mm b/ios/chrome/browser/translate/chrome_ios_translate_client_unittest.mm new file mode 100644 index 0000000..00fa7e7 --- /dev/null +++ b/ios/chrome/browser/translate/chrome_ios_translate_client_unittest.mm
@@ -0,0 +1,66 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "ios/chrome/browser/translate/chrome_ios_translate_client.h" +#import "base/test/metrics/histogram_tester.h" +#import "base/test/scoped_feature_list.h" +#import "base/test/task_environment.h" +#import "components/language/ios/browser/ios_language_detection_tab_helper.h" +#import "components/translate/core/common/translate_util.h" +#import "ios/chrome/browser/browser_state/test_chrome_browser_state.h" +#import "ios/chrome/browser/infobars/infobar_manager_impl.h" +#import "ios/chrome/browser/language/language_model_manager_factory.h" +#import "ios/chrome/browser/optimization_guide/optimization_guide_service.h" +#import "ios/chrome/browser/optimization_guide/optimization_guide_service_factory.h" +#import "ios/chrome/browser/translate/language_detection_model_service_factory.h" +#import "ios/chrome/browser/translate/translate_model_service_factory.h" +#import "ios/chrome/browser/translate/translate_ranker_factory.h" +#import "ios/web/public/test/fakes/fake_navigation_context.h" +#import "ios/web/public/test/fakes/fake_navigation_manager.h" +#import "ios/web/public/test/fakes/fake_web_state.h" +#import "testing/platform_test.h" + +#if !defined(__has_feature) || !__has_feature(objc_arc) +#error "This file requires ARC support." 
+#endif + +class ChromeIOSTranslateClientTest : public PlatformTest { + public: + void SetUp() override { + PlatformTest::SetUp(); + scoped_feature_list_.InitWithFeatures( + {translate::kTFLiteLanguageDetectionEnabled}, {}); + TestChromeBrowserState::Builder builder; + builder.AddTestingFactory( + OptimizationGuideServiceFactory::GetInstance(), + OptimizationGuideServiceFactory::GetDefaultFactory()); + + browser_state_ = builder.Build(); + + web_state_.SetNavigationManager( + std::make_unique<web::FakeNavigationManager>()); + web_state_.SetBrowserState(browser_state_.get()); + language::IOSLanguageDetectionTabHelper::CreateForWebState( + &web_state_, /*url_language_histogram=*/nullptr); + ChromeIOSTranslateClient::CreateForWebState(&web_state_); + InfoBarManagerImpl::CreateForWebState(&web_state_); + } + + protected: + base::test::TaskEnvironment task_environment_; + base::test::ScopedFeatureList scoped_feature_list_; + base::HistogramTester histogram_tester_; + std::unique_ptr<TestChromeBrowserState> browser_state_; + web::FakeWebState web_state_; +}; + +TEST_F(ChromeIOSTranslateClientTest, TranslateUICreated) { + ChromeIOSTranslateClient* translate_client = + ChromeIOSTranslateClient::FromWebState(&web_state_); + translate_client->ShowTranslateUI(translate::TRANSLATE_STEP_AFTER_TRANSLATE, + "en", "en", + translate::TranslateErrors::NONE, + /*triggered_from_menu=*/false); + EXPECT_EQ(1U, InfoBarManagerImpl::FromWebState(&web_state_)->infobar_count()); +}
diff --git a/ios/chrome/browser/ui/browser_view/browser_view_controller.mm b/ios/chrome/browser/ui/browser_view/browser_view_controller.mm index 2564e8a..76017fc 100644 --- a/ios/chrome/browser/ui/browser_view/browser_view_controller.mm +++ b/ios/chrome/browser/ui/browser_view/browser_view_controller.mm
@@ -4019,7 +4019,7 @@ NewTabPageTabHelper::FromWebState(self.currentWebState); if (NTPHelper) { NTPHelper->SetNextNTPFeedType(feedType); - NTPHelper->SetNextNTPScrolledToFeed(YES); + // TODO(crbug.com/1329173): Scroll into feed. } // Navigate to NTP in same tab.
diff --git a/ios/chrome/browser/ui/fullscreen/fullscreen_egtest.mm b/ios/chrome/browser/ui/fullscreen/fullscreen_egtest.mm index 6db5a25..f650d01 100644 --- a/ios/chrome/browser/ui/fullscreen/fullscreen_egtest.mm +++ b/ios/chrome/browser/ui/fullscreen/fullscreen_egtest.mm
@@ -341,7 +341,8 @@ // Tests that the header is shown when a native page is loaded from a page where // the header was not seen before. -- (void)testShowHeaderOnNativePageLoad { +// TODO(crbug.com/1329265): failing on waterfall +- (void)DISABLED_testShowHeaderOnNativePageLoad { std::map<GURL, std::string> responses; const GURL URL = web::test::HttpServer::MakeUrl("http://origin");
diff --git a/ios/google_internal/frameworks/chrome_internal_dynamic_framework.ios.zip.sha1 b/ios/google_internal/frameworks/chrome_internal_dynamic_framework.ios.zip.sha1 index d49ccf9..ddc9029 100644 --- a/ios/google_internal/frameworks/chrome_internal_dynamic_framework.ios.zip.sha1 +++ b/ios/google_internal/frameworks/chrome_internal_dynamic_framework.ios.zip.sha1
@@ -1 +1 @@ -3290366da8784b928acecd735d7bfaf9d07ff1ac \ No newline at end of file +ecee08b6cd6288f1e7940fbf97faa294da4fd527 \ No newline at end of file
diff --git a/ios/google_internal/frameworks/chrome_internal_dynamic_framework.iossimulator.zip.sha1 b/ios/google_internal/frameworks/chrome_internal_dynamic_framework.iossimulator.zip.sha1 index 4598566..a92070ed 100644 --- a/ios/google_internal/frameworks/chrome_internal_dynamic_framework.iossimulator.zip.sha1 +++ b/ios/google_internal/frameworks/chrome_internal_dynamic_framework.iossimulator.zip.sha1
@@ -1 +1 @@ -b4cfbc547237508e05190e9134dcfdce03a601de \ No newline at end of file +286510957af70cb404f1ffe323004fa88967110c \ No newline at end of file
diff --git a/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.ios.zip.sha1 b/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.ios.zip.sha1 index 0c216e10..feeec4b3 100644 --- a/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.ios.zip.sha1 +++ b/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.ios.zip.sha1
@@ -1 +1 @@ -c31cd4e9d7ba630f291de981c91115bb5f761287 \ No newline at end of file +eb60cdce6623cfb146979fbc5591b0c52267629c \ No newline at end of file
diff --git a/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.iossimulator.zip.sha1 b/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.iossimulator.zip.sha1 index cf9852d..7bcb08c04 100644 --- a/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.iossimulator.zip.sha1 +++ b/ios/google_internal/frameworks/chrome_sso_internal_dynamic_framework.iossimulator.zip.sha1
@@ -1 +1 @@ -8e7293333d2737fb092b44ccd98e2437b906f3ed \ No newline at end of file +c86b74bbe6adc5eb5b371137775c326393b00b70 \ No newline at end of file
diff --git a/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.ios.zip.sha1 b/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.ios.zip.sha1 index 2e61db0..aa51b49 100644 --- a/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.ios.zip.sha1 +++ b/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.ios.zip.sha1
@@ -1 +1 @@ -c95df183a2c63a3274d38d66ef7b6a37abc52f0b \ No newline at end of file +bfbc7e0daeae1faae5b88b58dad4d239300a88cb \ No newline at end of file
diff --git a/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.iossimulator.zip.sha1 b/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.iossimulator.zip.sha1 index da29d810..3fbf0de6 100644 --- a/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.iossimulator.zip.sha1 +++ b/ios/google_internal/frameworks/remoting_dogfood_internal_dynamic_framework.iossimulator.zip.sha1
@@ -1 +1 @@ -d994c69cca729f5d2123169f4ef9ac1d510aaec4 \ No newline at end of file +66e016576ac95a6023b080d4bdb597bcd6a0c74e \ No newline at end of file
diff --git a/ios/google_internal/frameworks/remoting_internal_dynamic_framework.ios.zip.sha1 b/ios/google_internal/frameworks/remoting_internal_dynamic_framework.ios.zip.sha1 index 7278d01e..dcbcbf4 100644 --- a/ios/google_internal/frameworks/remoting_internal_dynamic_framework.ios.zip.sha1 +++ b/ios/google_internal/frameworks/remoting_internal_dynamic_framework.ios.zip.sha1
@@ -1 +1 @@ -b5a28ea37fdf7edd050b7f42047cc44a8acb7755 \ No newline at end of file +7524a8c5fcd0e2e2510e49e489013e2ee180636b \ No newline at end of file
diff --git a/ios/google_internal/frameworks/remoting_internal_dynamic_framework.iossimulator.zip.sha1 b/ios/google_internal/frameworks/remoting_internal_dynamic_framework.iossimulator.zip.sha1 index ae7d7c2..6a5f1e9 100644 --- a/ios/google_internal/frameworks/remoting_internal_dynamic_framework.iossimulator.zip.sha1 +++ b/ios/google_internal/frameworks/remoting_internal_dynamic_framework.iossimulator.zip.sha1
@@ -1 +1 @@ -5ef2537f6ed0944520e77ba4cdf0be8359e5b578 \ No newline at end of file +53f98b1f7824dfc9b6e3b379c6bd74c91cfdbb45 \ No newline at end of file
diff --git a/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.ios.zip.sha1 b/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.ios.zip.sha1 index cc45ee8..a92e21ab 100644 --- a/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.ios.zip.sha1 +++ b/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.ios.zip.sha1
@@ -1 +1 @@ -5451dc662f5e924ca3646b46114f2ed28f1942ce \ No newline at end of file +197a31d00315a5b689a0d66e26e3b8b0fbcb6c60 \ No newline at end of file
diff --git a/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.iossimulator.zip.sha1 b/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.iossimulator.zip.sha1 index a54c3746..77fe712 100644 --- a/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.iossimulator.zip.sha1 +++ b/ios/google_internal/frameworks/web_view_shell_internal_dynamic_framework.iossimulator.zip.sha1
@@ -1 +1 @@ -da0d93e07d824c2f86f36458c442d8dbc1bd22ec \ No newline at end of file +1a48165be2f77d0782d31334ff584da0b463df5b \ No newline at end of file
diff --git a/ios/web/security/crw_cert_verification_controller_unittest.mm b/ios/web/security/crw_cert_verification_controller_unittest.mm index 4d8852cc..a32a788 100644 --- a/ios/web/security/crw_cert_verification_controller_unittest.mm +++ b/ios/web/security/crw_cert_verification_controller_unittest.mm
@@ -11,7 +11,7 @@ #import "ios/web/security/wk_web_view_security_util.h" #include "net/cert/x509_certificate.h" #include "net/cert/x509_util.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "net/test/cert_test_util.h" #include "net/test/test_data_directory.h"
diff --git a/ios/web/security/crw_ssl_status_updater_unittest.mm b/ios/web/security/crw_ssl_status_updater_unittest.mm index 3f4bd667..6631ebd 100644 --- a/ios/web/security/crw_ssl_status_updater_unittest.mm +++ b/ios/web/security/crw_ssl_status_updater_unittest.mm
@@ -15,7 +15,7 @@ #import "ios/web/security/wk_web_view_security_util.h" #import "ios/web/test/fakes/crw_fake_back_forward_list.h" #import "ios/web/test/fakes/fake_navigation_manager_delegate.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "net/test/cert_test_util.h" #include "net/test/test_data_directory.h" #include "third_party/ocmock/OCMock/OCMock.h"
diff --git a/ios/web/security/wk_web_view_security_util_unittest.mm b/ios/web/security/wk_web_view_security_util_unittest.mm index cbc018dd..ef067ba 100644 --- a/ios/web/security/wk_web_view_security_util_unittest.mm +++ b/ios/web/security/wk_web_view_security_util_unittest.mm
@@ -14,7 +14,7 @@ #include "crypto/rsa_private_key.h" #include "net/cert/x509_certificate.h" #include "net/cert/x509_util.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "net/ssl/ssl_info.h" #include "testing/gtest/include/gtest/gtest.h" #import "testing/gtest_mac.h"
diff --git a/ios/web/web_state/ui/crw_web_controller_unittest.mm b/ios/web/web_state/ui/crw_web_controller_unittest.mm index dce8f534..969873c63 100644 --- a/ios/web/web_state/ui/crw_web_controller_unittest.mm +++ b/ios/web/web_state/ui/crw_web_controller_unittest.mm
@@ -52,7 +52,7 @@ #import "ios/web/web_state/ui/crw_web_controller_container_view.h" #import "ios/web/web_state/web_state_impl.h" #import "net/base/mac/url_conversions.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "net/ssl/ssl_info.h" #include "net/test/cert_test_util.h" #include "net/test/test_data_directory.h"
diff --git a/media/cast/BUILD.gn b/media/cast/BUILD.gn index 6e6d0f2..6335908 100644 --- a/media/cast/BUILD.gn +++ b/media/cast/BUILD.gn
@@ -262,10 +262,18 @@ "test/loopback_transport.h", "test/mock_cast_transport.cc", "test/mock_cast_transport.h", + "test/mock_paced_packet_sender.cc", + "test/mock_paced_packet_sender.h", + "test/mock_rtp_payload_feedback.cc", + "test/mock_rtp_payload_feedback.h", + "test/rtp_packet_builder.cc", + "test/rtp_packet_builder.h", "test/skewed_single_thread_task_runner.cc", "test/skewed_single_thread_task_runner.h", "test/skewed_tick_clock.cc", "test/skewed_tick_clock.h", + "test/test_rtcp_packet_builder.cc", + "test/test_rtcp_packet_builder.h", "test/utility/audio_utility.cc", "test/utility/audio_utility.h", "test/utility/barcode.cc", @@ -319,6 +327,7 @@ sources = [ "common/expanded_value_base_unittest.cc", "common/rtp_time_unittest.cc", + "encoding/audio_encoder_unittest.cc", "encoding/external_video_encoder_unittest.cc", "encoding/video_encoder_unittest.cc", "encoding/vpx_quantizer_parser_unittest.cc", @@ -327,24 +336,12 @@ "logging/simple_event_subscriber_unittest.cc", "logging/stats_event_subscriber_unittest.cc", "net/cast_transport_impl_unittest.cc", - "net/pacing/mock_paced_packet_sender.cc", - "net/pacing/mock_paced_packet_sender.h", "net/pacing/paced_sender_unittest.cc", "net/rtcp/receiver_rtcp_event_subscriber_unittest.cc", "net/rtcp/rtcp_builder_unittest.cc", "net/rtcp/rtcp_unittest.cc", "net/rtcp/rtcp_utility_unittest.cc", - - # TODO(jophba): The following two are test utility modules. Rename/move the - # files. - "encoding/audio_encoder_unittest.cc", - "net/rtcp/test_rtcp_packet_builder.cc", - "net/rtcp/test_rtcp_packet_builder.h", - "net/rtp/mock_rtp_payload_feedback.cc", - "net/rtp/mock_rtp_payload_feedback.h", "net/rtp/packet_storage_unittest.cc", - "net/rtp/rtp_packet_builder.cc", - "net/rtp/rtp_packet_builder.h", "net/rtp/rtp_packetizer_unittest.cc", "net/rtp/rtp_parser_unittest.cc", "net/udp_packet_pipe_unittest.cc",
diff --git a/media/cast/net/rtcp/rtcp_builder_unittest.cc b/media/cast/net/rtcp/rtcp_builder_unittest.cc index 51d668d..3ec935c 100644 --- a/media/cast/net/rtcp/rtcp_builder_unittest.cc +++ b/media/cast/net/rtcp/rtcp_builder_unittest.cc
@@ -16,7 +16,7 @@ #include "media/cast/net/pacing/paced_sender.h" #include "media/cast/net/rtcp/receiver_rtcp_event_subscriber.h" #include "media/cast/net/rtcp/rtcp_utility.h" -#include "media/cast/net/rtcp/test_rtcp_packet_builder.h" +#include "media/cast/test/test_rtcp_packet_builder.h" #include "testing/gmock/include/gmock/gmock.h" namespace media {
diff --git a/media/cast/net/rtcp/rtcp_utility_unittest.cc b/media/cast/net/rtcp/rtcp_utility_unittest.cc index ca1d3edc..fa186a5 100644 --- a/media/cast/net/rtcp/rtcp_utility_unittest.cc +++ b/media/cast/net/rtcp/rtcp_utility_unittest.cc
@@ -12,7 +12,7 @@ #include "media/base/fake_single_thread_task_runner.h" #include "media/cast/cast_environment.h" #include "media/cast/net/cast_transport_defines.h" -#include "media/cast/net/rtcp/test_rtcp_packet_builder.h" +#include "media/cast/test/test_rtcp_packet_builder.h" #include "testing/gtest/include/gtest/gtest.h" namespace media {
diff --git a/media/cast/net/rtp/mock_rtp_feedback.h b/media/cast/net/rtp/mock_rtp_feedback.h deleted file mode 100644 index 0bb21b8b..0000000 --- a/media/cast/net/rtp/mock_rtp_feedback.h +++ /dev/null
@@ -1,38 +0,0 @@ -// Copyright 2014 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MEDIA_CAST_NET_RTP_MOCK_RTP_FEEDBACK_H_ -#define MEDIA_CAST_NET_RTP_MOCK_RTP_FEEDBACK_H_ - -#include <stdint.h> - -#include "media/cast/net/rtp/rtp_parser/rtp_feedback.h" -#include "testing/gmock/include/gmock/gmock.h" - -namespace media { -namespace cast { - -class MockRtpFeedback : public RtpFeedback { - public: - MOCK_METHOD4(OnInitializeDecoder, - int32_t(const int8_t payloadType, - const int frequency, - const uint8_t channels, - const uint32_t rate)); - - MOCK_METHOD1(OnPacketTimeout, void(const int32_t id)); - MOCK_METHOD2(OnReceivedPacket, - void(const int32_t id, const RtpRtcpPacketField packet_type)); - MOCK_METHOD2(OnPeriodicDeadOrAlive, - void(const int32_t id, const RTPAliveType alive)); - MOCK_METHOD2(OnIncomingSSRCChanged, - void(const int32_t id, const uint32_t ssrc)); - MOCK_METHOD3(OnIncomingCSRCChanged, - void(const int32_t id, const uint32_t csrc, const bool added)); -}; - -} // namespace cast -} // namespace media - -#endif // MEDIA_CAST_NET_RTP_MOCK_RTP_FEEDBACK_H_
diff --git a/media/cast/net/rtp/rtp_parser_unittest.cc b/media/cast/net/rtp/rtp_parser_unittest.cc index a832e0c..ce3272b 100644 --- a/media/cast/net/rtp/rtp_parser_unittest.cc +++ b/media/cast/net/rtp/rtp_parser_unittest.cc
@@ -11,7 +11,7 @@ #include "base/rand_util.h" #include "media/cast/net/rtp/rtp_defines.h" -#include "media/cast/net/rtp/rtp_packet_builder.h" +#include "media/cast/test/rtp_packet_builder.h" #include "testing/gtest/include/gtest/gtest.h" namespace media {
diff --git a/media/cast/net/pacing/mock_paced_packet_sender.cc b/media/cast/test/mock_paced_packet_sender.cc similarity index 85% rename from media/cast/net/pacing/mock_paced_packet_sender.cc rename to media/cast/test/mock_paced_packet_sender.cc index 16e396fd..623a396 100644 --- a/media/cast/net/pacing/mock_paced_packet_sender.cc +++ b/media/cast/test/mock_paced_packet_sender.cc
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "media/cast/net/pacing/mock_paced_packet_sender.h" +#include "media/cast/test/mock_paced_packet_sender.h" namespace media { namespace cast {
diff --git a/media/cast/net/pacing/mock_paced_packet_sender.h b/media/cast/test/mock_paced_packet_sender.h similarity index 68% rename from media/cast/net/pacing/mock_paced_packet_sender.h rename to media/cast/test/mock_paced_packet_sender.h index 0193ce8..4baabad 100644 --- a/media/cast/net/pacing/mock_paced_packet_sender.h +++ b/media/cast/test/mock_paced_packet_sender.h
@@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef MEDIA_CAST_NET_PACING_MOCK_PACED_PACKET_SENDER_H_ -#define MEDIA_CAST_NET_PACING_MOCK_PACED_PACKET_SENDER_H_ +#ifndef MEDIA_CAST_TEST_MOCK_PACED_PACKET_SENDER_H_ +#define MEDIA_CAST_TEST_MOCK_PACED_PACKET_SENDER_H_ #include "media/cast/net/pacing/paced_sender.h" #include "testing/gmock/include/gmock/gmock.h" @@ -17,8 +17,9 @@ ~MockPacedPacketSender() override; MOCK_METHOD1(SendPackets, bool(const SendPacketVector& packets)); - MOCK_METHOD2(ResendPackets, bool(const SendPacketVector& packets, - const DedupInfo& dedup_info)); + MOCK_METHOD2(ResendPackets, + bool(const SendPacketVector& packets, + const DedupInfo& dedup_info)); MOCK_METHOD2(SendRtcpPacket, bool(unsigned int ssrc, PacketRef packet)); MOCK_METHOD1(CancelSendingPacket, void(const PacketKey& packet_key)); }; @@ -26,4 +27,4 @@ } // namespace cast } // namespace media -#endif // MEDIA_CAST_NET_PACING_MOCK_PACED_PACKET_SENDER_H_ +#endif // MEDIA_CAST_TEST_MOCK_PACED_PACKET_SENDER_H_
diff --git a/media/cast/net/rtp/mock_rtp_payload_feedback.cc b/media/cast/test/mock_rtp_payload_feedback.cc similarity index 86% rename from media/cast/net/rtp/mock_rtp_payload_feedback.cc rename to media/cast/test/mock_rtp_payload_feedback.cc index fc87b1cf..52b585b 100644 --- a/media/cast/net/rtp/mock_rtp_payload_feedback.cc +++ b/media/cast/test/mock_rtp_payload_feedback.cc
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "media/cast/net/rtp/mock_rtp_payload_feedback.h" +#include "media/cast/test/mock_rtp_payload_feedback.h" namespace media { namespace cast {
diff --git a/media/cast/net/rtp/mock_rtp_payload_feedback.h b/media/cast/test/mock_rtp_payload_feedback.h similarity index 76% rename from media/cast/net/rtp/mock_rtp_payload_feedback.h rename to media/cast/test/mock_rtp_payload_feedback.h index 90c0943..7d810f78 100644 --- a/media/cast/net/rtp/mock_rtp_payload_feedback.h +++ b/media/cast/test/mock_rtp_payload_feedback.h
@@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef MEDIA_CAST_NET_RTP_MOCK_RTP_PAYLOAD_FEEDBACK_H_ -#define MEDIA_CAST_NET_RTP_MOCK_RTP_PAYLOAD_FEEDBACK_H_ +#ifndef MEDIA_CAST_TEST_MOCK_RTP_PAYLOAD_FEEDBACK_H_ +#define MEDIA_CAST_TEST_MOCK_RTP_PAYLOAD_FEEDBACK_H_ #include "media/cast/net/rtp/rtp_defines.h" #include "testing/gmock/include/gmock/gmock.h" @@ -22,4 +22,4 @@ } // namespace cast } // namespace media -#endif // MEDIA_CAST_NET_RTP_MOCK_RTP_PAYLOAD_FEEDBACK_H_ +#endif // MEDIA_CAST_TEST_MOCK_RTP_PAYLOAD_FEEDBACK_H_
diff --git a/media/cast/test/receiver/frame_receiver_unittest.cc b/media/cast/test/receiver/frame_receiver_unittest.cc index e0881a0..e818bb3 100644 --- a/media/cast/test/receiver/frame_receiver_unittest.cc +++ b/media/cast/test/receiver/frame_receiver_unittest.cc
@@ -21,8 +21,8 @@ #include "media/cast/logging/simple_event_subscriber.h" #include "media/cast/net/cast_transport_impl.h" #include "media/cast/net/rtcp/rtcp_utility.h" -#include "media/cast/net/rtcp/test_rtcp_packet_builder.h" #include "media/cast/test/mock_cast_transport.h" +#include "media/cast/test/test_rtcp_packet_builder.h" #include "media/cast/test/utility/default_config.h" #include "testing/gmock/include/gmock/gmock.h"
diff --git a/media/cast/test/receiver/framer_unittest.cc b/media/cast/test/receiver/framer_unittest.cc index 6d922c0a..6c032b6 100644 --- a/media/cast/test/receiver/framer_unittest.cc +++ b/media/cast/test/receiver/framer_unittest.cc
@@ -7,7 +7,7 @@ #include "base/test/simple_test_tick_clock.h" #include "media/cast/common/encoded_frame.h" #include "media/cast/net/cast_transport_defines.h" -#include "media/cast/net/rtp/mock_rtp_payload_feedback.h" +#include "media/cast/test/mock_rtp_payload_feedback.h" #include "media/cast/test/receiver/framer.h" #include "testing/gtest/include/gtest/gtest.h"
diff --git a/media/cast/net/rtp/rtp_packet_builder.cc b/media/cast/test/rtp_packet_builder.cc similarity index 93% rename from media/cast/net/rtp/rtp_packet_builder.cc rename to media/cast/test/rtp_packet_builder.cc index c49c718..d771923 100644 --- a/media/cast/net/rtp/rtp_packet_builder.cc +++ b/media/cast/test/rtp_packet_builder.cc
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "media/cast/net/rtp/rtp_packet_builder.h" +#include "media/cast/test/rtp_packet_builder.h" #include "base/big_endian.h" #include "base/check_op.h" @@ -22,7 +22,9 @@ payload_type_(0), ssrc_(0) {} -void RtpPacketBuilder::SetKeyFrame(bool is_key) { is_key_ = is_key; } +void RtpPacketBuilder::SetKeyFrame(bool is_key) { + is_key_ = is_key; +} void RtpPacketBuilder::SetFrameIds(uint32_t frame_id, uint32_t reference_frame_id) { @@ -46,7 +48,9 @@ sequence_number_ = sequence_number; } -void RtpPacketBuilder::SetMarkerBit(bool marker) { marker_ = marker; } +void RtpPacketBuilder::SetMarkerBit(bool marker) { + marker_ = marker; +} void RtpPacketBuilder::SetPayloadType(int payload_type) { payload_type_ = payload_type;
diff --git a/media/cast/net/rtp/rtp_packet_builder.h b/media/cast/test/rtp_packet_builder.h similarity index 89% rename from media/cast/net/rtp/rtp_packet_builder.h rename to media/cast/test/rtp_packet_builder.h index 67821f1..5ef002e 100644 --- a/media/cast/net/rtp/rtp_packet_builder.h +++ b/media/cast/test/rtp_packet_builder.h
@@ -4,8 +4,8 @@ // Test helper class that builds rtp packets. -#ifndef MEDIA_CAST_NET_RTP_RTP_PACKET_BUILDER_H_ -#define MEDIA_CAST_NET_RTP_RTP_PACKET_BUILDER_H_ +#ifndef MEDIA_CAST_TEST_RTP_PACKET_BUILDER_H_ +#define MEDIA_CAST_TEST_RTP_PACKET_BUILDER_H_ #include <stdint.h> @@ -51,4 +51,4 @@ } // namespace cast } // namespace media -#endif // MEDIA_CAST_NET_RTP_RTP_PACKET_BUILDER_H_ +#endif // MEDIA_CAST_TEST_RTP_PACKET_BUILDER_H_
diff --git a/media/cast/net/rtcp/test_rtcp_packet_builder.cc b/media/cast/test/test_rtcp_packet_builder.cc similarity index 98% rename from media/cast/net/rtcp/test_rtcp_packet_builder.cc rename to media/cast/test/test_rtcp_packet_builder.cc index 57a02034..34f0136f 100644 --- a/media/cast/net/rtcp/test_rtcp_packet_builder.cc +++ b/media/cast/test/test_rtcp_packet_builder.cc
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "media/cast/net/rtcp/test_rtcp_packet_builder.h" +#include "media/cast/test/test_rtcp_packet_builder.h" #include <memory> @@ -141,7 +141,7 @@ big_endian_writer_.WriteU8('S'); big_endian_writer_.WriteU8('T'); big_endian_writer_.WriteU8(kAckFrameId); - big_endian_writer_.WriteU8(3); // Loss fields. + big_endian_writer_.WriteU8(3); // Loss fields. big_endian_writer_.WriteU16(target_delay.InMilliseconds()); big_endian_writer_.WriteU8(kLostFrameId); big_endian_writer_.WriteU16(kRtcpCastAllPacketsLost);
diff --git a/media/cast/net/rtcp/test_rtcp_packet_builder.h b/media/cast/test/test_rtcp_packet_builder.h similarity index 95% rename from media/cast/net/rtcp/test_rtcp_packet_builder.h rename to media/cast/test/test_rtcp_packet_builder.h index 031f341..2f76810 100644 --- a/media/cast/net/rtcp/test_rtcp_packet_builder.h +++ b/media/cast/test/test_rtcp_packet_builder.h
@@ -4,8 +4,8 @@ // A very simple packet builder class for building RTCP packets. // Used for testing only. -#ifndef MEDIA_CAST_NET_RTCP_TEST_RTCP_PACKET_BUILDER_H_ -#define MEDIA_CAST_NET_RTCP_TEST_RTCP_PACKET_BUILDER_H_ +#ifndef MEDIA_CAST_TEST_TEST_RTCP_PACKET_BUILDER_H_ +#define MEDIA_CAST_TEST_TEST_RTCP_PACKET_BUILDER_H_ #include <stdint.h> #include <vector> @@ -112,4 +112,4 @@ } // namespace cast } // namespace media -#endif // MEDIA_CAST_NET_RTCP_TEST_RTCP_PACKET_BUILDER_H_ +#endif // MEDIA_CAST_TEST_TEST_RTCP_PACKET_BUILDER_H_
diff --git a/media/gpu/test/OWNERS b/media/gpu/test/OWNERS index 8efbdde4..6ca0a43d2 100644 --- a/media/gpu/test/OWNERS +++ b/media/gpu/test/OWNERS
@@ -1,2 +1 @@ -dstaessens@chromium.org hiroh@chromium.org
diff --git a/net/BUILD.gn b/net/BUILD.gn index 5cdb3af..3dde3258 100644 --- a/net/BUILD.gn +++ b/net/BUILD.gn
@@ -1238,8 +1238,8 @@ "base/platform_mime_util_mac.mm", "base/proxy_string_util_mac.cc", "cert/test_root_certs_mac.cc", - "cert/x509_util_ios_and_mac.cc", - "cert/x509_util_ios_and_mac.h", + "cert/x509_util_apple.cc", + "cert/x509_util_apple.h", "proxy_resolution/proxy_resolver_mac.cc", "proxy_resolution/proxy_resolver_mac.h", ] @@ -4388,12 +4388,15 @@ sources += [ "cert/cert_verify_proc_mac_unittest.cc", "cert/internal/trust_store_mac_unittest.cc", - "cert/x509_util_ios_and_mac_unittest.cc", "ssl/client_cert_store_mac_unittest.cc", "ssl/ssl_platform_key_mac_unittest.cc", ] } + if (is_apple) { + sources += [ "cert/x509_util_apple_unittest.cc" ] + } + if (is_win) { sources += [ "base/network_change_notifier_win_unittest.cc", @@ -4652,7 +4655,6 @@ "url_request/url_fetcher_impl_unittest.cc", "url_request/url_request_context_builder_unittest.cc", ] - sources += [ "cert/x509_util_ios_and_mac_unittest.cc" ] bundle_deps = [ ":net_unittests_bundle_data" ] }
diff --git a/net/cert/cert_verify_proc.cc b/net/cert/cert_verify_proc.cc index dcdbfb6..f10ab40 100644 --- a/net/cert/cert_verify_proc.cc +++ b/net/cert/cert_verify_proc.cc
@@ -495,23 +495,22 @@ int flags, CRLSet* crl_set, const CertificateList& additional_trust_anchors) { - base::Value dict(base::Value::Type::DICTIONARY); - dict.SetKey("certificates", NetLogX509CertificateList(cert)); + base::Value::Dict dict; + dict.Set("certificates", NetLogX509CertificateList(cert)); if (!ocsp_response.empty()) { - dict.SetStringKey("ocsp_response", - PEMEncode(ocsp_response, "NETLOG OCSP RESPONSE")); + dict.Set("ocsp_response", PEMEncode(ocsp_response, "NETLOG OCSP RESPONSE")); } if (!sct_list.empty()) { - dict.SetStringKey("sct_list", PEMEncode(sct_list, "NETLOG SCT LIST")); + dict.Set("sct_list", PEMEncode(sct_list, "NETLOG SCT LIST")); } - dict.SetKey("host", NetLogStringValue(hostname)); - dict.SetIntKey("verify_flags", flags); - dict.SetKey("crlset_sequence", NetLogNumberValue(crl_set->sequence())); + dict.Set("host", NetLogStringValue(hostname)); + dict.Set("verify_flags", flags); + dict.Set("crlset_sequence", NetLogNumberValue(crl_set->sequence())); if (crl_set->IsExpired()) - dict.SetBoolKey("crlset_is_expired", true); + dict.Set("crlset_is_expired", true); if (!additional_trust_anchors.empty()) { - base::Value certs(base::Value::Type::LIST); + base::Value::List certs; for (auto& anchor : additional_trust_anchors) { std::string pem_encoded; if (X509Certificate::GetPEMEncodedFromDER( @@ -520,10 +519,10 @@ certs.Append(std::move(pem_encoded)); } } - dict.SetKey("additional_trust_anchors", std::move(certs)); + dict.Set("additional_trust_anchors", std::move(certs)); } - return dict; + return base::Value(std::move(dict)); } } // namespace
diff --git a/net/cert/cert_verify_proc_builtin.cc b/net/cert/cert_verify_proc_builtin.cc index 7d805d0..4dd10d1 100644 --- a/net/cert/cert_verify_proc_builtin.cc +++ b/net/cert/cert_verify_proc_builtin.cc
@@ -62,12 +62,12 @@ std::string pem_encoded; if (X509Certificate::GetPEMEncodedFromDER( x509_util::CryptoBufferAsStringPiece(cert_handle), &pem_encoded)) { - results.SetStringKey("certificate", pem_encoded); + results.GetDict().Set("certificate", pem_encoded); } std::string errors_string = errors.ToDebugString(); if (!errors_string.empty()) - results.SetStringKey("errors", errors_string); + results.GetDict().Set("errors", errors_string); return results; } @@ -78,37 +78,37 @@ std::string pem; X509Certificate::GetPEMEncodedFromDER(cert->der_cert().AsStringPiece(), &pem); - value.Append(std::move(pem)); + value.GetList().Append(std::move(pem)); } return value; } base::Value NetLogPathBuilderResultPath( const CertPathBuilderResultPath& result_path) { - base::Value value(base::Value::Type::DICTIONARY); - value.SetBoolKey("is_valid", result_path.IsValid()); - value.SetIntKey("last_cert_trust", - static_cast<int>(result_path.last_cert_trust.type)); - value.SetKey("certificates", PEMCertListValue(result_path.certs)); + base::Value::Dict dict; + dict.Set("is_valid", result_path.IsValid()); + dict.Set("last_cert_trust", + static_cast<int>(result_path.last_cert_trust.type)); + dict.Set("certificates", PEMCertListValue(result_path.certs)); // TODO(crbug.com/634484): netlog user_constrained_policy_set. std::string errors_string = result_path.errors.ToDebugString(result_path.certs); if (!errors_string.empty()) - value.SetStringKey("errors", errors_string); - return value; + dict.Set("errors", errors_string); + return base::Value(std::move(dict)); } base::Value NetLogPathBuilderResult(const CertPathBuilder::Result& result) { - base::Value value(base::Value::Type::DICTIONARY); + base::Value::Dict dict; // TODO(crbug.com/634484): include debug data (or just have things netlog it // directly). 
- value.SetBoolKey("has_valid_path", result.HasValidPath()); - value.SetIntKey("best_result_index", result.best_result_index); + dict.Set("has_valid_path", result.HasValidPath()); + dict.Set("best_result_index", static_cast<int>(result.best_result_index)); if (result.exceeded_iteration_limit) - value.SetBoolKey("exceeded_iteration_limit", true); + dict.Set("exceeded_iteration_limit", true); if (result.exceeded_deadline) - value.SetBoolKey("exceeded_deadline", true); - return value; + dict.Set("exceeded_deadline", true); + return base::Value(std::move(dict)); } RevocationPolicy NoRevocationChecking() { @@ -802,11 +802,11 @@ verification_type = cur_attempt.verification_type; net_log.BeginEvent( NetLogEventType::CERT_VERIFY_PROC_PATH_BUILD_ATTEMPT, [&] { - base::DictionaryValue results; + base::Value results(base::Value::Type::DICTIONARY); if (verification_type == VerificationType::kEV) - results.SetBoolKey("is_ev_attempt", true); - results.SetIntKey("digest_policy", - static_cast<int>(cur_attempt.digest_policy)); + results.GetDict().Set("is_ev_attempt", true); + results.GetDict().Set("digest_policy", + static_cast<int>(cur_attempt.digest_policy)); return results; });
diff --git a/net/cert/cert_verify_proc_ios.cc b/net/cert/cert_verify_proc_ios.cc index dc863a8..ff4b80d 100644 --- a/net/cert/cert_verify_proc_ios.cc +++ b/net/cert/cert_verify_proc_ios.cc
@@ -19,8 +19,8 @@ #include "net/cert/known_roots.h" #include "net/cert/test_root_certs.h" #include "net/cert/x509_certificate.h" +#include "net/cert/x509_util_apple.h" #include "net/cert/x509_util_ios.h" -#include "net/cert/x509_util_ios_and_mac.h" using base::ScopedCFTypeRef;
diff --git a/net/cert/cert_verify_proc_mac.cc b/net/cert/cert_verify_proc_mac.cc index a96d3ee..e9816a8 100644 --- a/net/cert/cert_verify_proc_mac.cc +++ b/net/cert/cert_verify_proc_mac.cc
@@ -37,7 +37,7 @@ #include "net/cert/test_root_certs.h" #include "net/cert/x509_certificate.h" #include "net/cert/x509_util.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "net/cert/x509_util_mac.h" // CSSM functions are deprecated as of OSX 10.7, but have no replacement.
diff --git a/net/cert/cert_verify_result.cc b/net/cert/cert_verify_result.cc index a176927..6126e36 100644 --- a/net/cert/cert_verify_result.cc +++ b/net/cert/cert_verify_result.cc
@@ -65,30 +65,30 @@ } base::Value CertVerifyResult::NetLogParams(int net_error) const { - base::DictionaryValue results; + base::Value::Dict dict; DCHECK_NE(ERR_IO_PENDING, net_error); if (net_error < 0) - results.SetIntKey("net_error", net_error); - results.SetBoolKey("is_issued_by_known_root", is_issued_by_known_root); + dict.Set("net_error", net_error); + dict.Set("is_issued_by_known_root", is_issued_by_known_root); if (is_issued_by_additional_trust_anchor) { - results.SetBoolKey("is_issued_by_additional_trust_anchor", true); + dict.Set("is_issued_by_additional_trust_anchor", true); } - results.SetIntKey("cert_status", cert_status); + dict.Set("cert_status", static_cast<int>(cert_status)); // TODO(mattm): This double-wrapping of the certificate list is weird. Remove // this (probably requires updates to netlog-viewer). - base::Value certificate_dict(base::Value::Type::DICTIONARY); - certificate_dict.SetKey("certificates", - net::NetLogX509CertificateList(verified_cert.get())); - results.SetKey("verified_cert", std::move(certificate_dict)); + base::Value::Dict certificate_dict; + certificate_dict.Set("certificates", + net::NetLogX509CertificateList(verified_cert.get())); + dict.Set("verified_cert", std::move(certificate_dict)); - base::Value hashes(base::Value::Type::LIST); + base::Value::List hashes; for (const auto& public_key_hash : public_key_hashes) hashes.Append(public_key_hash.ToString()); - results.SetKey("public_key_hashes", std::move(hashes)); + dict.Set("public_key_hashes", std::move(hashes)); - results.SetKey("scts", net::NetLogSignedCertificateTimestampParams(&scts)); + dict.Set("scts", net::NetLogSignedCertificateTimestampParams(&scts)); - return std::move(results); + return base::Value(std::move(dict)); } } // namespace net
diff --git a/net/cert/coalescing_cert_verifier.cc b/net/cert/coalescing_cert_verifier.cc index 372a50b..a7c003ee 100644 --- a/net/cert/coalescing_cert_verifier.cc +++ b/net/cert/coalescing_cert_verifier.cc
@@ -72,21 +72,20 @@ namespace { base::Value CertVerifierParams(const CertVerifier::RequestParams& params) { - base::Value dict(base::Value::Type::DICTIONARY); - dict.SetKey("certificates", - NetLogX509CertificateList(params.certificate().get())); + base::Value::Dict dict; + dict.Set("certificates", + NetLogX509CertificateList(params.certificate().get())); if (!params.ocsp_response().empty()) { - dict.SetStringPath("ocsp_response", PEMEncode(params.ocsp_response(), - "NETLOG OCSP RESPONSE")); + dict.Set("ocsp_response", + PEMEncode(params.ocsp_response(), "NETLOG OCSP RESPONSE")); } if (!params.sct_list().empty()) { - dict.SetStringPath("sct_list", - PEMEncode(params.sct_list(), "NETLOG SCT LIST")); + dict.Set("sct_list", PEMEncode(params.sct_list(), "NETLOG SCT LIST")); } - dict.SetPath("host", NetLogStringValue(params.hostname())); - dict.SetIntPath("verifier_flags", params.flags()); + dict.Set("host", NetLogStringValue(params.hostname())); + dict.Set("verifier_flags", params.flags()); - return dict; + return base::Value(std::move(dict)); } } // namespace
diff --git a/net/cert/crl_set.cc b/net/cert/crl_set.cc index 56056f7..2579030 100644 --- a/net/cert/crl_set.cc +++ b/net/cert/crl_set.cc
@@ -54,7 +54,7 @@ // ReadHeader reads the header (including length prefix) from |data| and // updates |data| to remove the header on return. Caller takes ownership of the // returned pointer. -base::DictionaryValue* ReadHeader(base::StringPiece* data) { +std::unique_ptr<base::Value> ReadHeader(base::StringPiece* data) { uint16_t header_len; if (data->size() < sizeof(header_len)) return nullptr; @@ -75,7 +75,7 @@ if (!header->is_dict()) return nullptr; - return static_cast<base::DictionaryValue*>(header.release()); + return header; } // kCurrentFileVersion is the version of the CRLSet file format that we @@ -124,22 +124,21 @@ // the given |key| (without path expansion) in |header_dict| and sets |*out| // to the decoded values. It's not an error if |key| is not found in // |header_dict|. -bool CopyHashListFromHeader(base::DictionaryValue* header_dict, +bool CopyHashListFromHeader(const base::Value::Dict& header_dict, const char* key, std::vector<std::string>* out) { - const base::Value* list = header_dict->FindListKey(key); + const base::Value::List* list = header_dict.FindList(key); if (!list) { // Hash lists are optional so it's not an error if not present. return true; } - base::Value::ConstListView list_view = list->GetListDeprecated(); out->clear(); - out->reserve(list_view.size()); + out->reserve(list->size()); std::string sha256_base64; - for (const base::Value& i : list_view) { + for (const base::Value& i : *list) { sha256_base64.clear(); if (!i.is_string()) @@ -160,25 +159,24 @@ // hashes to lists of the same, from the given |key| in |header_dict|. It // copies the map data into |out| (after base64-decoding). 
bool CopyHashToHashesMapFromHeader( - base::DictionaryValue* header_dict, + const base::Value::Dict& header_dict, const char* key, std::unordered_map<std::string, std::vector<std::string>>* out) { out->clear(); - base::Value* const dict = - header_dict->FindKeyOfType(key, base::Value::Type::DICTIONARY); + const base::Value::Dict* dict = header_dict.FindDict(key); if (dict == nullptr) { // Maps are optional so it's not an error if not present. return true; } - for (auto i : dict->DictItems()) { + for (auto i : *dict) { if (!i.second.is_list()) { return false; } std::vector<std::string> allowed_spkis; - for (const auto& j : i.second.GetListDeprecated()) { + for (const auto& j : i.second.GetList()) { allowed_spkis.push_back(std::string()); if (!j.is_string() || !base::Base64Decode(j.GetString(), &allowed_spkis.back())) { @@ -219,23 +217,25 @@ #error assumes little endian #endif - std::unique_ptr<base::DictionaryValue> header_dict(ReadHeader(&data)); - if (!header_dict.get()) + std::unique_ptr<base::Value> header_value(ReadHeader(&data)); + if (!header_value.get()) return false; - std::string* contents = header_dict->FindStringKey("ContentType"); + const base::Value::Dict& header_dict = header_value->GetDict(); + + const std::string* contents = header_dict.FindString("ContentType"); if (!contents || (*contents != "CRLSet")) return false; - if (header_dict->FindIntKey("Version") != kCurrentFileVersion) + if (header_dict.FindInt("Version") != kCurrentFileVersion) return false; - absl::optional<int> sequence = header_dict->FindIntKey("Sequence"); + absl::optional<int> sequence = header_dict.FindInt("Sequence"); if (!sequence) return false; // NotAfter is optional for now. 
- double not_after = header_dict->FindDoubleKey("NotAfter").value_or(0); + double not_after = header_dict.FindDouble("NotAfter").value_or(0); if (not_after < 0) return false; @@ -255,13 +255,13 @@ } std::vector<std::string> blocked_interception_spkis; - if (!CopyHashListFromHeader(header_dict.get(), "BlockedSPKIs", + if (!CopyHashListFromHeader(header_dict, "BlockedSPKIs", &crl_set->blocked_spkis_) || - !CopyHashToHashesMapFromHeader(header_dict.get(), "LimitedSubjects", + !CopyHashToHashesMapFromHeader(header_dict, "LimitedSubjects", &crl_set->limited_subjects_) || - !CopyHashListFromHeader(header_dict.get(), "KnownInterceptionSPKIs", + !CopyHashListFromHeader(header_dict, "KnownInterceptionSPKIs", &crl_set->known_interception_spkis_) || - !CopyHashListFromHeader(header_dict.get(), "BlockedInterceptionSPKIs", + !CopyHashListFromHeader(header_dict, "BlockedInterceptionSPKIs", &blocked_interception_spkis)) { return false; }
diff --git a/net/cert/ct_log_response_parser.cc b/net/cert/ct_log_response_parser.cc index 90a8a9a..1d92485 100644 --- a/net/cert/ct_log_response_parser.cc +++ b/net/cert/ct_log_response_parser.cc
@@ -115,9 +115,8 @@ return false; } - const base::DictionaryValue* dict_value = nullptr; - if (!json_consistency_proof.GetAsDictionary(&dict_value) || - !dict_value->FindKey("consistency")) { + const base::Value::Dict* dict_value = json_consistency_proof.GetIfDict(); + if (!dict_value || !dict_value->Find("consistency")) { return false; }
diff --git a/net/cert/ct_signed_certificate_timestamp_log_param.cc b/net/cert/ct_signed_certificate_timestamp_log_param.cc index b3f52dc0..1ada1a9b 100644 --- a/net/cert/ct_signed_certificate_timestamp_log_param.cc +++ b/net/cert/ct_signed_certificate_timestamp_log_param.cc
@@ -22,11 +22,11 @@ // description |key|. void SetBinaryData(const char* key, base::StringPiece value, - base::Value* dict) { + base::Value::Dict& dict) { std::string b64_value; base::Base64Encode(value, &b64_value); - dict->SetStringKey(key, b64_value); + dict.Set(key, b64_value); } // Returns a dictionary where each key is a field of the SCT and its value @@ -34,27 +34,26 @@ // outputting a de-serialized SCT to the NetLog. base::Value SCTToDictionary(const ct::SignedCertificateTimestamp& sct, ct::SCTVerifyStatus status) { - base::Value out(base::Value::Type::DICTIONARY); + base::Value::Dict dict; - out.SetStringKey("origin", OriginToString(sct.origin)); - out.SetStringKey("verification_status", StatusToString(status)); - out.SetIntKey("version", sct.version); + dict.Set("origin", OriginToString(sct.origin)); + dict.Set("verification_status", StatusToString(status)); + dict.Set("version", sct.version); - SetBinaryData("log_id", sct.log_id, &out); + SetBinaryData("log_id", sct.log_id, dict); base::TimeDelta time_since_unix_epoch = sct.timestamp - base::Time::UnixEpoch(); - out.SetStringKey("timestamp", base::NumberToString( - time_since_unix_epoch.InMilliseconds())); - SetBinaryData("extensions", sct.extensions, &out); + dict.Set("timestamp", + base::NumberToString(time_since_unix_epoch.InMilliseconds())); + SetBinaryData("extensions", sct.extensions, dict); - out.SetStringKey("hash_algorithm", - HashAlgorithmToString(sct.signature.hash_algorithm)); - out.SetStringKey( - "signature_algorithm", - SignatureAlgorithmToString(sct.signature.signature_algorithm)); - SetBinaryData("signature_data", sct.signature.signature_data, &out); + dict.Set("hash_algorithm", + HashAlgorithmToString(sct.signature.hash_algorithm)); + dict.Set("signature_algorithm", + SignatureAlgorithmToString(sct.signature.signature_algorithm)); + SetBinaryData("signature_data", sct.signature.signature_data, dict); - return out; + return base::Value(std::move(dict)); } // Given a list of SCTs and 
their statuses, return a list Value where each item @@ -62,9 +61,10 @@ base::Value SCTListToPrintableValues( const SignedCertificateTimestampAndStatusList& sct_and_status_list) { base::Value output_scts(base::Value::Type::LIST); - for (const auto& sct_and_status : sct_and_status_list) - output_scts.Append( + for (const auto& sct_and_status : sct_and_status_list) { + output_scts.GetList().Append( SCTToDictionary(*(sct_and_status.sct.get()), sct_and_status.status)); + } return output_scts; } @@ -75,7 +75,7 @@ const SignedCertificateTimestampAndStatusList* scts) { base::Value dict(base::Value::Type::DICTIONARY); - dict.SetKey("scts", SCTListToPrintableValues(*scts)); + dict.GetDict().Set("scts", SCTListToPrintableValues(*scts)); return dict; } @@ -84,13 +84,13 @@ base::StringPiece embedded_scts, base::StringPiece sct_list_from_ocsp, base::StringPiece sct_list_from_tls_extension) { - base::Value dict(base::Value::Type::DICTIONARY); + base::Value::Dict dict; - SetBinaryData("embedded_scts", embedded_scts, &dict); - SetBinaryData("scts_from_ocsp_response", sct_list_from_ocsp, &dict); - SetBinaryData("scts_from_tls_extension", sct_list_from_tls_extension, &dict); + SetBinaryData("embedded_scts", embedded_scts, dict); + SetBinaryData("scts_from_ocsp_response", sct_list_from_ocsp, dict); + SetBinaryData("scts_from_tls_extension", sct_list_from_tls_extension, dict); - return dict; + return base::Value(std::move(dict)); } } // namespace net
diff --git a/net/cert/internal/trust_store_mac.cc b/net/cert/internal/trust_store_mac.cc index 128c376..d6a54e7 100644 --- a/net/cert/internal/trust_store_mac.cc +++ b/net/cert/internal/trust_store_mac.cc
@@ -25,7 +25,7 @@ #include "net/cert/known_roots_mac.h" #include "net/cert/test_keychain_search_list_mac.h" #include "net/cert/x509_util.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "net/cert/x509_util_mac.h" #include "third_party/boringssl/src/include/openssl/sha.h"
diff --git a/net/cert/internal/trust_store_mac_unittest.cc b/net/cert/internal/trust_store_mac_unittest.cc index ddfe489..095b90d 100644 --- a/net/cert/internal/trust_store_mac_unittest.cc +++ b/net/cert/internal/trust_store_mac_unittest.cc
@@ -20,7 +20,7 @@ #include "net/cert/test_keychain_search_list_mac.h" #include "net/cert/x509_certificate.h" #include "net/cert/x509_util.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "net/test/test_data_directory.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h"
diff --git a/net/cert/test_root_certs_mac.cc b/net/cert/test_root_certs_mac.cc index 54ccc38..f57ad926 100644 --- a/net/cert/test_root_certs_mac.cc +++ b/net/cert/test_root_certs_mac.cc
@@ -10,7 +10,7 @@ #include "net/cert/internal/cert_errors.h" #include "net/cert/x509_certificate.h" #include "net/cert/x509_util.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" namespace net {
diff --git a/net/cert/trial_comparison_cert_verifier.cc b/net/cert/trial_comparison_cert_verifier.cc index a778355ca..a1d93ed 100644 --- a/net/cert/trial_comparison_cert_verifier.cc +++ b/net/cert/trial_comparison_cert_verifier.cc
@@ -31,7 +31,7 @@ base::Value JobResultParams(bool trial_success) { base::Value results(base::Value::Type::DICTIONARY); - results.SetBoolKey("trial_success", trial_success); + results.GetDict().Set("trial_success", trial_success); return results; }
diff --git a/net/cert/x509_util_ios_and_mac.cc b/net/cert/x509_util_apple.cc similarity index 98% rename from net/cert/x509_util_ios_and_mac.cc rename to net/cert/x509_util_apple.cc index 3635389..e5a15ce 100644 --- a/net/cert/x509_util_ios_and_mac.cc +++ b/net/cert/x509_util_apple.cc
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "base/logging.h" #include "build/build_config.h"
diff --git a/net/cert/x509_util_ios_and_mac.h b/net/cert/x509_util_apple.h similarity index 93% rename from net/cert/x509_util_ios_and_mac.h rename to net/cert/x509_util_apple.h index e195471a..1784414 100644 --- a/net/cert/x509_util_ios_and_mac.h +++ b/net/cert/x509_util_apple.h
@@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef NET_CERT_X509_UTIL_IOS_AND_MAC_H_ -#define NET_CERT_X509_UTIL_IOS_AND_MAC_H_ +#ifndef NET_CERT_X509_UTIL_APPLE_H_ +#define NET_CERT_X509_UTIL_APPLE_H_ #include <CoreFoundation/CFArray.h> #include <Security/Security.h> @@ -54,4 +54,4 @@ } // namespace net -#endif // NET_CERT_X509_UTIL_IOS_AND_MAC_H_ +#endif // NET_CERT_X509_UTIL_APPLE_H_
diff --git a/net/cert/x509_util_ios_and_mac_unittest.cc b/net/cert/x509_util_apple_unittest.cc similarity index 99% rename from net/cert/x509_util_ios_and_mac_unittest.cc rename to net/cert/x509_util_apple_unittest.cc index 34dcfed..2a49bf9e 100644 --- a/net/cert/x509_util_ios_and_mac_unittest.cc +++ b/net/cert/x509_util_apple_unittest.cc
@@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "build/build_config.h" #include "net/cert/x509_certificate.h"
diff --git a/net/http/http_cache_transaction.cc b/net/http/http_cache_transaction.cc index 2007ceb0..d675b12 100644 --- a/net/http/http_cache_transaction.cc +++ b/net/http/http_cache_transaction.cc
@@ -640,6 +640,10 @@ } void HttpCache::Transaction::WriterAboutToBeRemovedFromEntry(int result) { + TRACE_EVENT_WITH_FLOW0( + "io", "HttpCacheTransaction::WriterAboutToBeRemovedFromEntry", + net_log().source().id, + TRACE_EVENT_FLAG_FLOW_IN | TRACE_EVENT_FLAG_FLOW_OUT); // Since the transaction can no longer access the network transaction, save // all network related info now. if (moved_network_transaction_to_writers_ && @@ -658,6 +662,10 @@ } void HttpCache::Transaction::WriteModeTransactionAboutToBecomeReader() { + TRACE_EVENT_WITH_FLOW0( + "io", "HttpCacheTransaction::WriteModeTransactionAboutToBecomeReader", + net_log().source().id, + TRACE_EVENT_FLAG_FLOW_IN | TRACE_EVENT_FLAG_FLOW_OUT); mode_ = READ; if (moved_network_transaction_to_writers_ && entry_->writers->network_transaction()) { @@ -2242,6 +2250,9 @@ } int HttpCache::Transaction::DoFinishHeaders(int result) { + TRACE_EVENT_WITH_FLOW0("io", "HttpCacheTransaction::DoFinishHeaders", + net_log().source().id, + TRACE_EVENT_FLAG_FLOW_IN | TRACE_EVENT_FLAG_FLOW_OUT); if (!cache_.get() || !entry_ || result != OK) { TransitionToState(STATE_NONE); return result; @@ -2272,6 +2283,9 @@ } int HttpCache::Transaction::DoFinishHeadersComplete(int rv) { + TRACE_EVENT_WITH_FLOW0("io", "HttpCacheTransaction::DoFinishHeadersComplete", + net_log().source().id, + TRACE_EVENT_FLAG_FLOW_IN | TRACE_EVENT_FLAG_FLOW_OUT); entry_lock_waiting_since_ = TimeTicks(); if (rv == ERR_CACHE_RACE || rv == ERR_CACHE_LOCK_TIMEOUT) { TransitionToState(STATE_HEADERS_PHASE_CANNOT_PROCEED);
diff --git a/net/http/http_response_headers.cc b/net/http/http_response_headers.cc index 84cd546..0821f596 100644 --- a/net/http/http_response_headers.cc +++ b/net/http/http_response_headers.cc
@@ -1401,4 +1401,4 @@ dict.Add("headers", parsed_); } -} // namespace net +} // namespace net \ No newline at end of file
diff --git a/net/http/http_server_properties_manager.cc b/net/http/http_server_properties_manager.cc index 3513f4f..7f7e696 100644 --- a/net/http/http_server_properties_manager.cc +++ b/net/http/http_server_properties_manager.cc
@@ -107,29 +107,26 @@ void AddAlternativeServiceFieldsToDictionaryValue( const AlternativeService& alternative_service, - base::Value* dict) { - DCHECK(dict->is_dict()); - dict->SetIntKey(kPortKey, alternative_service.port); + base::Value::Dict& dict) { + dict.Set(kPortKey, alternative_service.port); if (!alternative_service.host.empty()) { - dict->SetStringKey(kHostKey, alternative_service.host); + dict.Set(kHostKey, alternative_service.host); } - dict->SetStringKey(kProtocolKey, - NextProtoToString(alternative_service.protocol)); + dict.Set(kProtocolKey, NextProtoToString(alternative_service.protocol)); } // Fails in the case of NetworkIsolationKeys that can't be persisted to disk, // like unique origins. bool TryAddBrokenAlternativeServiceFieldsToDictionaryValue( const BrokenAlternativeService& broken_alt_service, - base::Value* dict) { - DCHECK(dict->is_dict()); + base::Value::Dict& dict) { base::Value network_isolation_key_value; if (!broken_alt_service.network_isolation_key.ToValue( &network_isolation_key_value)) { return false; } - dict->SetKey(kNetworkIsolationKey, std::move(network_isolation_key_value)); + dict.Set(kNetworkIsolationKey, std::move(network_isolation_key_value)); AddAlternativeServiceFieldsToDictionaryValue( broken_alt_service.alternative_service, dict); return true; @@ -153,20 +150,18 @@ (server_id.privacy_mode_enabled() ? "/private" : ""); } -// Takes in a base::Value representing a dictionary, and whether -// NetworkIsolationKeys are enabled for HttpServerProperties, and extracts the -// NetworkIsolationKey stored with the |kNetworkIsolationKey| in the dictionary, -// and writes it to |out_network_isolation_key|. 
Returns false if unable to load -// a NetworkIsolationKey, or the NetworkIsolationKey is non-empty, but +// Takes in a base::Value::Dict, and whether NetworkIsolationKeys are enabled +// for HttpServerProperties, and extracts the NetworkIsolationKey stored with +// the |kNetworkIsolationKey| in the dictionary, and writes it to +// |out_network_isolation_key|. Returns false if unable to load a +// NetworkIsolationKey, or the NetworkIsolationKey is non-empty, but // |use_network_isolation_key| is false. bool GetNetworkIsolationKeyFromDict( - const base::Value& dict, + const base::Value::Dict& dict, bool use_network_isolation_key, NetworkIsolationKey* out_network_isolation_key) { - DCHECK(dict.is_dict()); - const base::Value* network_isolation_key_value = - dict.FindKey(kNetworkIsolationKey); + dict.Find(kNetworkIsolationKey); NetworkIsolationKey network_isolation_key; if (!network_isolation_key_value || !NetworkIsolationKey::FromValue(*network_isolation_key_value, @@ -230,16 +225,19 @@ net_log_.EndEvent(NetLogEventType::HTTP_SERVER_PROPERTIES_INITIALIZATION); - const base::Value* http_server_properties_dict = + const base::Value* http_server_properties_value = pref_delegate_->GetServerProperties(); // If there are no preferences set, do nothing. - if (!http_server_properties_dict || !http_server_properties_dict->is_dict()) + if (!http_server_properties_value || !http_server_properties_value->is_dict()) return; + const base::Value::Dict& http_server_properties_dict = + http_server_properties_value->GetDict(); + net_log_.AddEvent(NetLogEventType::HTTP_SERVER_PROPERTIES_UPDATE_CACHE, - [&] { return http_server_properties_dict->Clone(); }); + [&] { return http_server_properties_value->Clone(); }); absl::optional<int> maybe_version_number = - http_server_properties_dict->FindIntKey(kVersionKey); + http_server_properties_dict.FindInt(kVersionKey); if (!maybe_version_number.has_value() || *maybe_version_number != kVersionNumber) { DVLOG(1) << "Missing or unsupported. 
Clearing all properties. " @@ -260,14 +258,14 @@ // ... // ], ... // }, - const base::Value* servers_list = - http_server_properties_dict->FindListKey(kServersKey); + const base::Value::List* servers_list = + http_server_properties_dict.FindList(kServersKey); if (!servers_list) { DVLOG(1) << "Malformed http_server_properties for servers list."; return; } - ReadLastLocalAddressWhenQuicWorked(*http_server_properties_dict, + ReadLastLocalAddressWhenQuicWorked(http_server_properties_dict, last_local_address_when_quic_worked); *server_info_map = std::make_unique<HttpServerProperties::ServerInfoMap>(); @@ -281,24 +279,23 @@ // Iterate servers list in reverse MRU order so that entries are inserted // into |spdy_servers_map|, |alternative_service_map|, and // |server_network_stats_map| from oldest to newest. - for (auto it = servers_list->GetListDeprecated().end(); - it != servers_list->GetListDeprecated().begin();) { + for (auto it = servers_list->end(); it != servers_list->begin();) { --it; if (!it->is_dict()) { DVLOG(1) << "Malformed http_server_properties for servers dictionary."; continue; } - AddServerData(*it, server_info_map->get(), use_network_isolation_key); + AddServerData(it->GetDict(), server_info_map->get(), + use_network_isolation_key); } - AddToQuicServerInfoMap(*http_server_properties_dict, - use_network_isolation_key, + AddToQuicServerInfoMap(http_server_properties_dict, use_network_isolation_key, quic_server_info_map->get()); // Read list containing broken and recently-broken alternative services, if // it exists. 
- const base::Value* broken_alt_svc_list = - http_server_properties_dict->FindListKey(kBrokenAlternativeServicesKey); + const base::Value::List* broken_alt_svc_list = + http_server_properties_dict.FindList(kBrokenAlternativeServicesKey); if (broken_alt_svc_list) { *broken_alternative_service_list = std::make_unique<BrokenAlternativeServiceList>(); @@ -307,15 +304,15 @@ kMaxRecentlyBrokenAlternativeServiceEntries); // Iterate list in reverse-MRU order - for (auto it = broken_alt_svc_list->GetListDeprecated().end(); - it != broken_alt_svc_list->GetListDeprecated().begin();) { + for (auto it = broken_alt_svc_list->end(); + it != broken_alt_svc_list->begin();) { --it; if (!it->is_dict()) { DVLOG(1) << "Malformed broken alterantive service entry."; continue; } AddToBrokenAlternativeServices( - *it, use_network_isolation_key, + it->GetDict(), use_network_isolation_key, broken_alternative_service_list->get(), recently_broken_alternative_services->get()); } @@ -341,7 +338,7 @@ } void HttpServerPropertiesManager::AddToBrokenAlternativeServices( - const base::Value& broken_alt_svc_entry_dict, + const base::Value::Dict& broken_alt_svc_entry_dict, bool use_network_isolation_key, BrokenAlternativeServiceList* broken_alternative_service_list, RecentlyBrokenAlternativeServices* recently_broken_alternative_services) { @@ -364,9 +361,9 @@ // Read broken-count and add an entry for |alt_service| into // |recently_broken_alternative_services|. - if (broken_alt_svc_entry_dict.FindKey(kBrokenCountKey)) { + if (broken_alt_svc_entry_dict.Find(kBrokenCountKey)) { absl::optional<int> broken_count = - broken_alt_svc_entry_dict.FindIntKey(kBrokenCountKey); + broken_alt_svc_entry_dict.FindInt(kBrokenCountKey); if (!broken_count.has_value()) { DVLOG(1) << "Recently broken alternative service has malformed " << "broken-count."; @@ -385,9 +382,9 @@ // Read broken-until and add an entry for |alt_service| in // |broken_alternative_service_list|. 
- if (broken_alt_svc_entry_dict.FindKey(kBrokenUntilKey)) { + if (broken_alt_svc_entry_dict.Find(kBrokenUntilKey)) { const std::string* expiration_string = - broken_alt_svc_entry_dict.FindStringKey(kBrokenUntilKey); + broken_alt_svc_entry_dict.FindString(kBrokenUntilKey); int64_t expiration_int64; if (!expiration_string || !base::StringToInt64(*expiration_string, &expiration_int64)) { @@ -415,11 +412,11 @@ } void HttpServerPropertiesManager::AddServerData( - const base::Value& server_dict, + const base::Value::Dict& server_dict, HttpServerProperties::ServerInfoMap* server_info_map, bool use_network_isolation_key) { // Get server's scheme/host/pair. - const std::string* server_str = server_dict.FindStringKey(kServerKey); + const std::string* server_str = server_dict.FindString(kServerKey); NetworkIsolationKey network_isolation_key; // Can't load entry if server name missing, or if the network isolation key is // missing or invalid. @@ -437,7 +434,7 @@ HttpServerProperties::ServerInfo server_info; - server_info.supports_spdy = server_dict.FindBoolKey(kSupportsSpdyKey); + server_info.supports_spdy = server_dict.FindBool(kSupportsSpdyKey); if (ParseAlternativeServiceInfo(spdy_server, server_dict, &server_info)) ParseNetworkStats(spdy_server, server_dict, &server_info); @@ -451,12 +448,12 @@ } bool HttpServerPropertiesManager::ParseAlternativeServiceDict( - const base::Value& dict, + const base::Value::Dict& dict, bool host_optional, const std::string& parsing_under, AlternativeService* alternative_service) { // Protocol is mandatory. - const std::string* protocol_str = dict.FindStringKey(kProtocolKey); + const std::string* protocol_str = dict.FindString(kProtocolKey); if (!protocol_str) { DVLOG(1) << "Malformed alternative service protocol string under: " << parsing_under; @@ -473,8 +470,8 @@ // If host is optional, it defaults to "". 
std::string host = ""; const std::string* hostp = nullptr; - if (dict.FindKey(kHostKey)) { - hostp = dict.FindStringKey(kHostKey); + if (dict.Find(kHostKey)) { + hostp = dict.FindString(kHostKey); if (!hostp) { DVLOG(1) << "Malformed alternative service host string under: " << parsing_under; @@ -489,7 +486,7 @@ alternative_service->host = host; // Port is mandatory. - absl::optional<int> maybe_port = dict.FindIntKey(kPortKey); + absl::optional<int> maybe_port = dict.FindInt(kPortKey); if (!maybe_port.has_value() || !IsPortValid(maybe_port.value())) { DVLOG(1) << "Malformed alternative service port under: " << parsing_under; return false; @@ -500,7 +497,7 @@ } bool HttpServerPropertiesManager::ParseAlternativeServiceInfoDictOfServer( - const base::Value& dict, + const base::Value::Dict& dict, const std::string& server_str, AlternativeServiceInfo* alternative_service_info) { AlternativeService alternative_service; @@ -511,10 +508,10 @@ alternative_service_info->set_alternative_service(alternative_service); // Expiration is optional, defaults to one day. - if (!dict.FindKey(kExpirationKey)) { + if (!dict.Find(kExpirationKey)) { alternative_service_info->set_expiration(base::Time::Now() + base::Days(1)); } else { - const std::string* expiration_string = dict.FindStringKey(kExpirationKey); + const std::string* expiration_string = dict.FindString(kExpirationKey); if (expiration_string) { int64_t expiration_int64 = 0; if (!base::StringToInt64(*expiration_string, &expiration_int64)) { @@ -532,15 +529,15 @@ } // Advertised versions list is optional. 
- if (dict.FindKey(kAdvertisedAlpnsKey)) { - const base::Value* versions_list = dict.FindListKey(kAdvertisedAlpnsKey); + if (dict.Find(kAdvertisedAlpnsKey)) { + const base::Value::List* versions_list = dict.FindList(kAdvertisedAlpnsKey); if (!versions_list) { DVLOG(1) << "Malformed alternative service advertised versions list for " << "server: " << server_str; return false; } quic::ParsedQuicVersionVector advertised_versions; - for (const auto& value : versions_list->GetListDeprecated()) { + for (const auto& value : *versions_list) { const std::string* version_string = value.GetIfString(); if (!version_string) { DVLOG(1) << "Malformed alternative service version for server: " @@ -561,11 +558,11 @@ bool HttpServerPropertiesManager::ParseAlternativeServiceInfo( const url::SchemeHostPort& server, - const base::Value& server_pref_dict, + const base::Value::Dict& server_pref_dict, HttpServerProperties::ServerInfo* server_info) { DCHECK(!server_info->alternative_services.has_value()); - const base::Value* alternative_service_list = - server_pref_dict.FindListKey(kAlternativeServiceKey); + const base::Value::List* alternative_service_list = + server_pref_dict.FindList(kAlternativeServiceKey); if (!alternative_service_list) { return true; } @@ -574,14 +571,13 @@ } AlternativeServiceInfoVector alternative_service_info_vector; - for (const auto& alternative_service_list_item : - alternative_service_list->GetListDeprecated()) { + for (const auto& alternative_service_list_item : *alternative_service_list) { if (!alternative_service_list_item.is_dict()) return false; AlternativeServiceInfo alternative_service_info; - if (!ParseAlternativeServiceInfoDictOfServer(alternative_service_list_item, - server.Serialize(), - &alternative_service_info)) { + if (!ParseAlternativeServiceInfoDictOfServer( + alternative_service_list_item.GetDict(), server.Serialize(), + &alternative_service_info)) { return false; } if (base::Time::Now() < alternative_service_info.expiration()) { @@ -598,14 
+594,14 @@ } void HttpServerPropertiesManager::ReadLastLocalAddressWhenQuicWorked( - const base::Value& http_server_properties_dict, + const base::Value::Dict& http_server_properties_dict, IPAddress* last_local_address_when_quic_worked) { - const base::Value* supports_quic_dict = - http_server_properties_dict.FindDictKey(kSupportsQuicKey); + const base::Value::Dict* supports_quic_dict = + http_server_properties_dict.FindDict(kSupportsQuicKey); if (!supports_quic_dict) { return; } - const base::Value* used_quic = supports_quic_dict->FindKey(kUsedQuicKey); + const base::Value* used_quic = supports_quic_dict->Find(kUsedQuicKey); if (!used_quic || !used_quic->is_bool()) { DVLOG(1) << "Malformed SupportsQuic"; return; @@ -613,7 +609,7 @@ if (!used_quic->GetBool()) return; - const std::string* address = supports_quic_dict->FindStringKey(kAddressKey); + const std::string* address = supports_quic_dict->FindString(kAddressKey); if (!address || !last_local_address_when_quic_worked->AssignFromIPLiteral(*address)) { DVLOG(1) << "Malformed SupportsQuic"; @@ -622,16 +618,15 @@ void HttpServerPropertiesManager::ParseNetworkStats( const url::SchemeHostPort& server, - const base::Value& server_pref_dict, + const base::Value::Dict& server_pref_dict, HttpServerProperties::ServerInfo* server_info) { DCHECK(!server_info->server_network_stats.has_value()); - const base::Value* server_network_stats_dict = - server_pref_dict.FindDictKey(kNetworkStatsKey); + const base::Value::Dict* server_network_stats_dict = + server_pref_dict.FindDict(kNetworkStatsKey); if (!server_network_stats_dict) { return; } - absl::optional<int> maybe_srtt = - server_network_stats_dict->FindIntKey(kSrttKey); + absl::optional<int> maybe_srtt = server_network_stats_dict->FindInt(kSrttKey); if (!maybe_srtt.has_value()) { DVLOG(1) << "Malformed ServerNetworkStats for server: " << server.Serialize(); @@ -645,23 +640,24 @@ } void HttpServerPropertiesManager::AddToQuicServerInfoMap( - const base::Value& 
http_server_properties_dict, + const base::Value::Dict& http_server_properties_dict, bool use_network_isolation_key, HttpServerProperties::QuicServerInfoMap* quic_server_info_map) { - const base::Value* quic_server_info_list = - http_server_properties_dict.FindListKey(kQuicServers); + const base::Value::List* quic_server_info_list = + http_server_properties_dict.FindList(kQuicServers); if (!quic_server_info_list) { DVLOG(1) << "Malformed http_server_properties for quic_servers."; return; } - for (const auto& quic_server_info_value : - quic_server_info_list->GetListDeprecated()) { - if (!quic_server_info_value.is_dict()) + for (const auto& quic_server_info_value : *quic_server_info_list) { + const base::Value::Dict* quic_server_info_dict = + quic_server_info_value.GetIfDict(); + if (!quic_server_info_dict) continue; const std::string* quic_server_id_str = - quic_server_info_value.FindStringKey(kQuicServerIdKey); + quic_server_info_dict->FindString(kQuicServerIdKey); if (!quic_server_id_str || quic_server_id_str->empty()) continue; @@ -674,7 +670,7 @@ } NetworkIsolationKey network_isolation_key; - if (!GetNetworkIsolationKeyFromDict(quic_server_info_value, + if (!GetNetworkIsolationKeyFromDict(*quic_server_info_dict, use_network_isolation_key, &network_isolation_key)) { DVLOG(1) << "Malformed http_server_properties quic server dict: " @@ -683,7 +679,7 @@ } const std::string* quic_server_info = - quic_server_info_value.FindStringKey(kServerInfoKey); + quic_server_info_dict->FindString(kServerInfoKey); if (!quic_server_info) { DVLOG(1) << "Malformed http_server_properties quic server info: " << *quic_server_id_str; @@ -714,11 +710,13 @@ std::set<std::pair<std::string, NetworkIsolationKey>> persisted_canonical_suffix_set; const base::Time now = base::Time::Now(); - base::Value http_server_properties_dict(base::Value::Type::DICTIONARY); + base::Value http_server_properties_value(base::Value::Type::DICTIONARY); + base::Value::Dict& http_server_properties_dict = + 
http_server_properties_value.GetDict(); - // Convert |server_info_map| to a dictionary Value and add it to + // Convert |server_info_map| to a list Value and add it to // |http_server_properties_dict|. - base::Value servers_list(base::Value::Type::LIST); + base::Value::List servers_list; for (const auto& [key, server_info] : base::Reversed(server_info_map)) { // If can't convert the NetworkIsolationKey to a value, don't save to disk. // Generally happens because the key is for a unique origin. @@ -726,139 +724,137 @@ if (!key.network_isolation_key.ToValue(&network_isolation_key_value)) continue; - base::Value server_dict(base::Value::Type::DICTIONARY); + base::Value::Dict server_dict; bool supports_spdy = server_info.supports_spdy.value_or(false); if (supports_spdy) - server_dict.SetBoolKey(kSupportsSpdyKey, supports_spdy); + server_dict.Set(kSupportsSpdyKey, supports_spdy); AlternativeServiceInfoVector alternative_services = GetAlternativeServiceToPersist(server_info.alternative_services, key, now, get_canonical_suffix, &persisted_canonical_suffix_set); if (!alternative_services.empty()) - SaveAlternativeServiceToServerPrefs(alternative_services, &server_dict); + SaveAlternativeServiceToServerPrefs(alternative_services, server_dict); if (server_info.server_network_stats) { SaveNetworkStatsToServerPrefs(*server_info.server_network_stats, - &server_dict); + server_dict); } // Don't add empty entries. This can happen if, for example, all alternative // services are empty, or |supports_spdy| is set to false, and all other // fields are not set. 
- if (server_dict.DictEmpty()) + if (server_dict.empty()) continue; - server_dict.SetStringKey(kServerKey, key.server.Serialize()); - server_dict.SetKey(kNetworkIsolationKey, - std::move(network_isolation_key_value)); + server_dict.Set(kServerKey, key.server.Serialize()); + server_dict.Set(kNetworkIsolationKey, + std::move(network_isolation_key_value)); servers_list.Append(std::move(server_dict)); } - http_server_properties_dict.SetKey(kServersKey, std::move(servers_list)); + http_server_properties_dict.Set(kServersKey, std::move(servers_list)); - http_server_properties_dict.SetIntKey(kVersionKey, kVersionNumber); + http_server_properties_dict.Set(kVersionKey, kVersionNumber); SaveLastLocalAddressWhenQuicWorkedToPrefs(last_local_address_when_quic_worked, - &http_server_properties_dict); + http_server_properties_dict); SaveQuicServerInfoMapToServerPrefs(quic_server_info_map, - &http_server_properties_dict); + http_server_properties_dict); SaveBrokenAlternativeServicesToPrefs( broken_alternative_service_list, kMaxBrokenAlternativeServicesToPersist, - recently_broken_alternative_services, &http_server_properties_dict); + recently_broken_alternative_services, http_server_properties_dict); - pref_delegate_->SetServerProperties(http_server_properties_dict, + pref_delegate_->SetServerProperties(http_server_properties_value, std::move(callback)); net_log_.AddEvent(NetLogEventType::HTTP_SERVER_PROPERTIES_UPDATE_PREFS, - [&] { return http_server_properties_dict.Clone(); }); + [&] { return http_server_properties_value.Clone(); }); } void HttpServerPropertiesManager::SaveAlternativeServiceToServerPrefs( const AlternativeServiceInfoVector& alternative_service_info_vector, - base::Value* server_pref_dict) { + base::Value::Dict& server_pref_dict) { if (alternative_service_info_vector.empty()) { return; } - base::Value alternative_service_list(base::Value::Type::LIST); + base::Value::List alternative_service_list; for (const AlternativeServiceInfo& alternative_service_info : 
alternative_service_info_vector) { const AlternativeService& alternative_service = alternative_service_info.alternative_service(); DCHECK(IsAlternateProtocolValid(alternative_service.protocol)); - base::Value alternative_service_dict(base::Value::Type::DICTIONARY); + base::Value::Dict alternative_service_dict; AddAlternativeServiceFieldsToDictionaryValue(alternative_service, - &alternative_service_dict); + alternative_service_dict); // JSON cannot store int64_t, so expiration is converted to a string. - alternative_service_dict.SetStringKey( + alternative_service_dict.Set( kExpirationKey, base::NumberToString( alternative_service_info.expiration().ToInternalValue())); - base::Value advertised_versions_list(base::Value::Type::LIST); + base::Value::List advertised_versions_list; for (const auto& version : alternative_service_info.advertised_versions()) { advertised_versions_list.Append(quic::AlpnForVersion(version)); } - alternative_service_dict.SetKey(kAdvertisedAlpnsKey, - std::move(advertised_versions_list)); + alternative_service_dict.Set(kAdvertisedAlpnsKey, + std::move(advertised_versions_list)); alternative_service_list.Append(std::move(alternative_service_dict)); } - if (alternative_service_list.GetListDeprecated().size() == 0) + if (alternative_service_list.size() == 0) return; - server_pref_dict->SetKey(kAlternativeServiceKey, - std::move(alternative_service_list)); + server_pref_dict.Set(kAlternativeServiceKey, + std::move(alternative_service_list)); } void HttpServerPropertiesManager::SaveLastLocalAddressWhenQuicWorkedToPrefs( const IPAddress& last_local_address_when_quic_worked, - base::Value* http_server_properties_dict) { + base::Value::Dict& http_server_properties_dict) { if (!last_local_address_when_quic_worked.IsValid()) return; - base::Value supports_quic_dict(base::Value::Type::DICTIONARY); - supports_quic_dict.SetBoolKey(kUsedQuicKey, true); - supports_quic_dict.SetStringKey( - kAddressKey, last_local_address_when_quic_worked.ToString()); - 
http_server_properties_dict->SetKey(kSupportsQuicKey, - std::move(supports_quic_dict)); + base::Value::Dict supports_quic_dict; + supports_quic_dict.Set(kUsedQuicKey, true); + supports_quic_dict.Set(kAddressKey, + last_local_address_when_quic_worked.ToString()); + http_server_properties_dict.Set(kSupportsQuicKey, + std::move(supports_quic_dict)); } void HttpServerPropertiesManager::SaveNetworkStatsToServerPrefs( const ServerNetworkStats& server_network_stats, - base::Value* server_pref_dict) { - base::Value server_network_stats_dict(base::Value::Type::DICTIONARY); - // Becasue JSON doesn't support int64_t, persist int64_t as a string. - server_network_stats_dict.SetIntKey( + base::Value::Dict& server_pref_dict) { + base::Value::Dict server_network_stats_dict; + // Because JSON doesn't support int64_t, persist int64_t as a string. + server_network_stats_dict.Set( kSrttKey, static_cast<int>(server_network_stats.srtt.InMicroseconds())); // TODO(rtenneti): When QUIC starts using bandwidth_estimate, then persist // bandwidth_estimate. - server_pref_dict->SetKey(kNetworkStatsKey, - std::move(server_network_stats_dict)); + server_pref_dict.Set(kNetworkStatsKey, std::move(server_network_stats_dict)); } void HttpServerPropertiesManager::SaveQuicServerInfoMapToServerPrefs( const HttpServerProperties::QuicServerInfoMap& quic_server_info_map, - base::Value* http_server_properties_dict) { + base::Value::Dict& http_server_properties_dict) { if (quic_server_info_map.empty()) return; - base::Value quic_servers_list(base::Value::Type::LIST); + base::Value::List quic_servers_list; for (const auto& [key, server_info] : base::Reversed(quic_server_info_map)) { base::Value network_isolation_key_value; // Don't save entries with ephemeral NIKs. 
if (!key.network_isolation_key.ToValue(&network_isolation_key_value)) continue; - base::Value quic_server_pref_dict(base::Value::Type::DICTIONARY); - quic_server_pref_dict.SetStringKey(kQuicServerIdKey, - QuicServerIdToString(key.server_id)); - quic_server_pref_dict.SetKey(kNetworkIsolationKey, - std::move(network_isolation_key_value)); - quic_server_pref_dict.SetStringKey(kServerInfoKey, server_info); + base::Value::Dict quic_server_pref_dict; + quic_server_pref_dict.Set(kQuicServerIdKey, + QuicServerIdToString(key.server_id)); + quic_server_pref_dict.Set(kNetworkIsolationKey, + std::move(network_isolation_key_value)); + quic_server_pref_dict.Set(kServerInfoKey, server_info); quic_servers_list.Append(std::move(quic_server_pref_dict)); } - http_server_properties_dict->SetKey(kQuicServers, - std::move(quic_servers_list)); + http_server_properties_dict.Set(kQuicServers, std::move(quic_servers_list)); } void HttpServerPropertiesManager::SaveBrokenAlternativeServicesToPrefs( @@ -866,7 +862,7 @@ size_t max_broken_alternative_services, const RecentlyBrokenAlternativeServices& recently_broken_alternative_services, - base::Value* http_server_properties_dict) { + base::Value::Dict& http_server_properties_dict) { if (broken_alternative_service_list.empty() && recently_broken_alternative_services.empty()) { return; @@ -874,7 +870,7 @@ // JSON list will be in MRU order according to // |recently_broken_alternative_services|. - base::Value json_list(base::Value::Type::LIST); + base::Value::List json_list; // Maps recently-broken alternative services to the index where it's stored // in |json_list|. 
@@ -883,14 +879,13 @@ if (!recently_broken_alternative_services.empty()) { for (const auto& [broken_alt_service, broken_count] : base::Reversed(recently_broken_alternative_services)) { - base::Value entry_dict(base::Value::Type::DICTIONARY); + base::Value::Dict entry_dict; if (!TryAddBrokenAlternativeServiceFieldsToDictionaryValue( - broken_alt_service, &entry_dict)) { + broken_alt_service, entry_dict)) { continue; } - entry_dict.SetKey(kBrokenCountKey, base::Value(broken_count)); - json_list_index_map[broken_alt_service] = - json_list.GetListDeprecated().size(); + entry_dict.Set(kBrokenCountKey, broken_count); + json_list_index_map[broken_alt_service] = json_list.size(); json_list.Append(std::move(entry_dict)); } } @@ -914,20 +909,18 @@ auto index_map_it = json_list_index_map.find(broken_alt_service); if (index_map_it != json_list_index_map.end()) { size_t json_list_index = index_map_it->second; - base::Value& entry_dict = - json_list.GetListDeprecated()[json_list_index]; + base::Value& entry_dict = json_list[json_list_index]; DCHECK(entry_dict.is_dict()); - DCHECK(!entry_dict.FindKey(kBrokenUntilKey)); - entry_dict.SetKey(kBrokenUntilKey, - base::Value(base::NumberToString(expiration_int64))); + DCHECK(!entry_dict.GetDict().Find(kBrokenUntilKey)); + entry_dict.GetDict().Set(kBrokenUntilKey, + base::NumberToString(expiration_int64)); } else { - base::Value entry_dict(base::Value::Type::DICTIONARY); + base::Value::Dict entry_dict; if (!TryAddBrokenAlternativeServiceFieldsToDictionaryValue( - broken_alt_service, &entry_dict)) { + broken_alt_service, entry_dict)) { continue; } - entry_dict.SetKey(kBrokenUntilKey, - base::Value(base::NumberToString(expiration_int64))); + entry_dict.Set(kBrokenUntilKey, base::NumberToString(expiration_int64)); json_list.Append(std::move(entry_dict)); } } @@ -935,11 +928,11 @@ // This can happen if all the entries are for NetworkIsolationKeys for opaque // origins, which isn't exactly common, but can theoretically happen. 
- if (json_list.GetListDeprecated().empty()) + if (json_list.empty()) return; - http_server_properties_dict->SetKey(kBrokenAlternativeServicesKey, - std::move(json_list)); + http_server_properties_dict.Set(kBrokenAlternativeServicesKey, + std::move(json_list)); } void HttpServerPropertiesManager::OnHttpServerPropertiesLoaded() {
diff --git a/net/http/http_server_properties_manager.h b/net/http/http_server_properties_manager.h index add11943..9be5265 100644 --- a/net/http/http_server_properties_manager.h +++ b/net/http/http_server_properties_manager.h
@@ -132,7 +132,7 @@ FRIEND_TEST_ALL_PREFIXES(HttpServerPropertiesManagerTest, AdvertisedVersionsRoundTrip); - void AddServerData(const base::Value& server_dict, + void AddServerData(const base::Value::Dict& server_dict, HttpServerProperties::ServerInfoMap* server_info_map, bool use_network_isolation_key); @@ -146,13 +146,13 @@ // |alternative_service| is the output of parsing |dict|. // Return value is true if parsing is successful. static bool ParseAlternativeServiceDict( - const base::Value& dict, + const base::Value::Dict& dict, bool host_optional, const std::string& parsing_under, AlternativeService* alternative_service); static bool ParseAlternativeServiceInfoDictOfServer( - const base::Value& dict, + const base::Value::Dict& dict, const std::string& server_str, AlternativeServiceInfo* alternative_service_info); @@ -161,43 +161,43 @@ // not considered corruption). static bool ParseAlternativeServiceInfo( const url::SchemeHostPort& server, - const base::Value& server_dict, + const base::Value::Dict& server_dict, HttpServerProperties::ServerInfo* server_info); void ReadLastLocalAddressWhenQuicWorked( - const base::Value& server_dict, + const base::Value::Dict& server_dict, IPAddress* last_local_address_when_quic_worked); void ParseNetworkStats(const url::SchemeHostPort& server, - const base::Value& server_dict, + const base::Value::Dict& server_dict, HttpServerProperties::ServerInfo* server_info); void AddToQuicServerInfoMap( - const base::Value& server_dict, + const base::Value::Dict& server_dict, bool use_network_isolation_key, HttpServerProperties::QuicServerInfoMap* quic_server_info_map); void AddToBrokenAlternativeServices( - const base::Value& broken_alt_svc_entry_dict, + const base::Value::Dict& broken_alt_svc_entry_dict, bool use_network_isolation_key, BrokenAlternativeServiceList* broken_alternative_service_list, RecentlyBrokenAlternativeServices* recently_broken_alternative_services); void SaveAlternativeServiceToServerPrefs( const 
AlternativeServiceInfoVector& alternative_service_info_vector, - base::Value* server_pref_dict); + base::Value::Dict& server_pref_dict); void SaveLastLocalAddressWhenQuicWorkedToPrefs( const IPAddress& last_local_address_when_quic_worked, - base::Value* http_server_properties_dict); + base::Value::Dict& http_server_properties_dict); void SaveNetworkStatsToServerPrefs( const ServerNetworkStats& server_network_stats, - base::Value* server_pref_dict); + base::Value::Dict& server_pref_dict); void SaveQuicServerInfoMapToServerPrefs( const HttpServerProperties::QuicServerInfoMap& quic_server_info_map, - base::Value* http_server_properties_dict); + base::Value::Dict& http_server_properties_dict); void SaveBrokenAlternativeServicesToPrefs( const BrokenAlternativeServiceList& broken_alternative_service_list, size_t max_broken_alternative_services, const RecentlyBrokenAlternativeServices& recently_broken_alternative_services, - base::Value* http_server_properties_dict); + base::Value::Dict& http_server_properties_dict); void OnHttpServerPropertiesLoaded();
diff --git a/net/http/http_server_properties_manager_unittest.cc b/net/http/http_server_properties_manager_unittest.cc index bda011f..df6ed2e 100644 --- a/net/http/http_server_properties_manager_unittest.cc +++ b/net/http/http_server_properties_manager_unittest.cc
@@ -279,9 +279,9 @@ // Returns a dictionary with only the version field populated. static base::Value DictWithVersion() { - base::Value http_server_properties_dict(base::Value::Type::DICTIONARY); - http_server_properties_dict.SetIntKey("version", 5); - return http_server_properties_dict; + base::Value::Dict http_server_properties_dict; + http_server_properties_dict.Set("version", 5); + return base::Value(std::move(http_server_properties_dict)); } raw_ptr<MockPrefDelegate> @@ -292,43 +292,42 @@ }; TEST_F(HttpServerPropertiesManagerTest, BadCachedHostPortPair) { - base::Value server_pref_dict(base::Value::Type::DICTIONARY); + base::Value::Dict server_pref_dict; // Set supports_spdy for www.google.com:65536. - server_pref_dict.SetBoolKey("supports_spdy", true); + server_pref_dict.Set("supports_spdy", true); // Set up alternative_service for www.google.com:65536. - base::Value alternative_service_dict(base::Value::Type::DICTIONARY); - alternative_service_dict.SetStringKey("protocol_str", "h2"); - alternative_service_dict.SetIntKey("port", 80); - base::Value alternative_service_list(base::Value::Type::LIST); + base::Value::Dict alternative_service_dict; + alternative_service_dict.Set("protocol_str", "h2"); + alternative_service_dict.Set("port", 80); + base::Value::List alternative_service_list; alternative_service_list.Append(std::move(alternative_service_dict)); - server_pref_dict.SetKey("alternative_service", - std::move(alternative_service_list)); + server_pref_dict.Set("alternative_service", + std::move(alternative_service_list)); // Set up ServerNetworkStats for www.google.com:65536. - base::Value stats(base::Value::Type::DICTIONARY); - stats.SetIntKey("srtt", 10); - server_pref_dict.SetKey("network_stats", std::move(stats)); + base::Value::Dict stats; + stats.Set("srtt", 10); + server_pref_dict.Set("network_stats", std::move(stats)); // Set the server preference for www.google.com:65536. 
- base::Value servers_dict(base::Value::Type::DICTIONARY); - servers_dict.SetKey("www.google.com:65536", std::move(server_pref_dict)); - base::Value servers_list(base::Value::Type::LIST); + base::Value::Dict servers_dict; + servers_dict.Set("www.google.com:65536", std::move(server_pref_dict)); + base::Value::List servers_list; servers_list.Append(std::move(servers_dict)); base::Value http_server_properties_dict = DictWithVersion(); - http_server_properties_dict.SetKey("servers", std::move(servers_list)); + http_server_properties_dict.GetDict().Set("servers", std::move(servers_list)); // Set quic_server_info for www.google.com:65536. - base::Value quic_servers_dict(base::Value::Type::DICTIONARY); - base::Value quic_server_pref_dict1(base::Value::Type::DICTIONARY); - quic_server_pref_dict1.SetKey("server_info", - base::Value("quic_server_info1")); - quic_servers_dict.SetKey("http://mail.google.com:65536", - std::move(quic_server_pref_dict1)); + base::Value::Dict quic_servers_dict; + base::Value::Dict quic_server_pref_dict1; + quic_server_pref_dict1.Set("server_info", "quic_server_info1"); + quic_servers_dict.Set("http://mail.google.com:65536", + std::move(quic_server_pref_dict1)); - http_server_properties_dict.SetKey("quic_servers", - std::move(quic_servers_dict)); + http_server_properties_dict.GetDict().Set("quic_servers", + std::move(quic_servers_dict)); // Set up the pref. InitializePrefs(http_server_properties_dict); @@ -349,27 +348,27 @@ } TEST_F(HttpServerPropertiesManagerTest, BadCachedAltProtocolPort) { - base::Value server_pref_dict(base::Value::Type::DICTIONARY); + base::Value::Dict server_pref_dict; // Set supports_spdy for www.google.com:80. - server_pref_dict.SetBoolKey("supports_spdy", true); + server_pref_dict.Set("supports_spdy", true); // Set up alternative_service for www.google.com:80. 
- base::Value alternative_service_dict(base::Value::Type::DICTIONARY); - alternative_service_dict.SetStringKey("protocol_str", "h2"); - alternative_service_dict.SetIntKey("port", 65536); - base::Value alternative_service_list(base::Value::Type::LIST); + base::Value::Dict alternative_service_dict; + alternative_service_dict.Set("protocol_str", "h2"); + alternative_service_dict.Set("port", 65536); + base::Value::List alternative_service_list; alternative_service_list.Append(std::move(alternative_service_dict)); - server_pref_dict.SetKey("alternative_service", - std::move(alternative_service_list)); + server_pref_dict.Set("alternative_service", + std::move(alternative_service_list)); // Set the server preference for www.google.com:80. - base::Value servers_dict(base::Value::Type::DICTIONARY); - servers_dict.SetKey("www.google.com:80", std::move(server_pref_dict)); - base::Value servers_list(base::Value::Type::LIST); + base::Value::Dict servers_dict; + servers_dict.Set("www.google.com:80", std::move(server_pref_dict)); + base::Value::List servers_list; servers_list.Append(std::move(servers_dict)); base::Value http_server_properties_dict = DictWithVersion(); - http_server_properties_dict.SetKey("servers", std::move(servers_list)); + http_server_properties_dict.GetDict().Set("servers", std::move(servers_list)); // Set up the pref. InitializePrefs(http_server_properties_dict); @@ -1028,38 +1027,37 @@ // https://crbug.com/444956: Add 200 alternative_service servers followed by // supports_quic and verify we have read supports_quic from prefs. TEST_F(HttpServerPropertiesManagerTest, BadLastLocalAddressWhenQuicWorked) { - base::Value servers_list(base::Value::Type::LIST); + base::Value::List servers_list; for (int i = 1; i <= 200; ++i) { // Set up alternative_service for www.google.com:i. 
- base::Value server_dict(base::Value::Type::DICTIONARY); - base::Value alternative_service_dict(base::Value::Type::DICTIONARY); - alternative_service_dict.SetStringKey("protocol_str", "quic"); - alternative_service_dict.SetIntKey("port", i); - base::Value alternative_service_list(base::Value::Type::LIST); + base::Value::Dict server_dict; + base::Value::Dict alternative_service_dict; + alternative_service_dict.Set("protocol_str", "quic"); + alternative_service_dict.Set("port", i); + base::Value::List alternative_service_list; alternative_service_list.Append(std::move(alternative_service_dict)); - server_dict.SetKey("alternative_service", - std::move(alternative_service_list)); - server_dict.SetStringKey("server", - StringPrintf("https://www.google.com:%d", i)); - server_dict.SetKey("isolation", base::Value(base::Value::Type::LIST)); + server_dict.Set("alternative_service", std::move(alternative_service_list)); + server_dict.Set("server", StringPrintf("https://www.google.com:%d", i)); + server_dict.Set("isolation", base::Value(base::Value::Type::LIST)); servers_list.Append(std::move(server_dict)); } // Set the server preference for http://mail.google.com server. 
- base::Value server_dict2(base::Value::Type::DICTIONARY); - server_dict2.SetStringKey("server", "https://mail.google.com"); - server_dict2.SetKey("isolation", base::Value(base::Value::Type::LIST)); + base::Value::Dict server_dict2; + server_dict2.Set("server", "https://mail.google.com"); + server_dict2.Set("isolation", base::Value(base::Value::Type::LIST)); servers_list.Append(std::move(server_dict2)); base::Value http_server_properties_dict = DictWithVersion(); - http_server_properties_dict.SetKey("servers", std::move(servers_list)); + http_server_properties_dict.GetDict().Set("servers", std::move(servers_list)); // Set up SupportsQuic for 127.0.0.1 - base::Value supports_quic(base::Value::Type::DICTIONARY); - supports_quic.SetBoolKey("used_quic", true); - supports_quic.SetStringKey("address", "127.0.0.1"); - http_server_properties_dict.SetKey("supports_quic", std::move(supports_quic)); + base::Value::Dict supports_quic; + supports_quic.Set("used_quic", true); + supports_quic.Set("address", "127.0.0.1"); + http_server_properties_dict.GetDict().Set("supports_quic", + std::move(supports_quic)); // Set up the pref. InitializePrefs(http_server_properties_dict); @@ -1174,17 +1172,16 @@ ASSERT_TRUE(server_dict.is_dict()); // Extract and remove the "broken_until" string for "www.google.com:1234". 
- base::Value* broken_alt_svc_list = - server_dict.FindListKey("broken_alternative_services"); + base::Value::List* broken_alt_svc_list = + server_dict.GetDict().FindList("broken_alternative_services"); ASSERT_TRUE(broken_alt_svc_list); - ASSERT_EQ(2u, broken_alt_svc_list->GetListDeprecated().size()); - base::Value& broken_alt_svcs_list_entry = - broken_alt_svc_list->GetListDeprecated()[0]; + ASSERT_EQ(2u, broken_alt_svc_list->size()); + base::Value& broken_alt_svcs_list_entry = (*broken_alt_svc_list)[0]; const std::string* broken_until_str = - broken_alt_svcs_list_entry.FindStringKey("broken_until"); + broken_alt_svcs_list_entry.GetDict().FindString("broken_until"); ASSERT_TRUE(broken_until_str); const std::string expiration_string = *broken_until_str; - broken_alt_svcs_list_entry.RemoveKey("broken_until"); + broken_alt_svcs_list_entry.GetDict().Remove("broken_until"); // Expiration time of "www.google.com:1234" should be 5 minutes minus the // update-prefs-delay from when the prefs were written. 
@@ -1248,7 +1245,7 @@ const url::SchemeHostPort server("https", "example.com", 443); HttpServerProperties::ServerInfo server_info; EXPECT_TRUE(HttpServerPropertiesManager::ParseAlternativeServiceInfo( - server, *server_dict, &server_info)); + server, server_dict->GetDict(), &server_info)); ASSERT_TRUE(server_info.alternative_services.has_value()); AlternativeServiceInfoVector alternative_service_info_vector = @@ -1302,7 +1299,7 @@ const url::SchemeHostPort server("http", "example.com", 80); HttpServerProperties::ServerInfo server_info; EXPECT_FALSE(HttpServerPropertiesManager::ParseAlternativeServiceInfo( - server, *server_dict, &server_info)); + server, server_dict->GetDict(), &server_info)); EXPECT_TRUE(server_info.empty()); } @@ -1352,38 +1349,40 @@ const base::Value* pref_dict = pref_delegate_->GetServerProperties(); - const base::Value* servers_list = pref_dict->FindListKey("servers"); + const base::Value::List* servers_list = + pref_dict->GetDict().FindList("servers"); ASSERT_TRUE(servers_list); - auto it = servers_list->GetListDeprecated().begin(); + auto it = servers_list->begin(); const base::Value& server_pref_dict = *it; ASSERT_TRUE(server_pref_dict.is_dict()); - const std::string* server_str = server_pref_dict.FindStringKey("server"); + const std::string* server_str = + server_pref_dict.GetDict().FindString("server"); ASSERT_TRUE(server_str); EXPECT_EQ("https://www.example.com", *server_str); const base::Value* network_isolation_key_value = - server_pref_dict.FindKey("isolation"); + server_pref_dict.GetDict().Find("isolation"); ASSERT_TRUE(network_isolation_key_value); ASSERT_EQ(base::Value::Type::LIST, network_isolation_key_value->type()); - EXPECT_TRUE(network_isolation_key_value->GetListDeprecated().empty()); + EXPECT_TRUE(network_isolation_key_value->GetList().empty()); - const base::Value* altsvc_list = - server_pref_dict.FindListKey("alternative_service"); + const base::Value::List* altsvc_list = + 
server_pref_dict.GetDict().FindList("alternative_service"); ASSERT_TRUE(altsvc_list); - ASSERT_EQ(2u, altsvc_list->GetListDeprecated().size()); + ASSERT_EQ(2u, altsvc_list->size()); - const base::Value& altsvc_entry = altsvc_list->GetListDeprecated()[0]; + const base::Value& altsvc_entry = (*altsvc_list)[0]; ASSERT_TRUE(altsvc_entry.is_dict()); - const std::string* hostname = altsvc_entry.FindStringKey("host"); + const std::string* hostname = altsvc_entry.GetDict().FindString("host"); ASSERT_TRUE(hostname); EXPECT_EQ("broken.example.com", *hostname); - const base::Value& altsvc_entry2 = altsvc_list->GetListDeprecated()[1]; + const base::Value& altsvc_entry2 = (*altsvc_list)[1]; ASSERT_TRUE(altsvc_entry.is_dict()); - hostname = altsvc_entry2.FindStringKey("host"); + hostname = altsvc_entry2.GetDict().FindString("host"); ASSERT_TRUE(hostname); EXPECT_EQ("valid.example.com", *hostname); } @@ -1392,27 +1391,27 @@ TEST_F(HttpServerPropertiesManagerTest, DoNotLoadExpiredAlternativeService) { InitializePrefs(); - base::Value alternative_service_list(base::Value::Type::LIST); - base::Value expired_dict(base::Value::Type::DICTIONARY); - expired_dict.SetStringKey("protocol_str", "h2"); - expired_dict.SetStringKey("host", "expired.example.com"); - expired_dict.SetIntKey("port", 443); + base::Value::List alternative_service_list; + base::Value::Dict expired_dict; + expired_dict.Set("protocol_str", "h2"); + expired_dict.Set("host", "expired.example.com"); + expired_dict.Set("port", 443); base::Time time_one_day_ago = base::Time::Now() - base::Days(1); - expired_dict.SetStringKey( - "expiration", base::NumberToString(time_one_day_ago.ToInternalValue())); + expired_dict.Set("expiration", + base::NumberToString(time_one_day_ago.ToInternalValue())); alternative_service_list.Append(std::move(expired_dict)); - base::Value valid_dict(base::Value::Type::DICTIONARY); - valid_dict.SetStringKey("protocol_str", "h2"); - valid_dict.SetStringKey("host", "valid.example.com"); - 
valid_dict.SetIntKey("port", 443); - valid_dict.SetStringKey( - "expiration", base::NumberToString(one_day_from_now_.ToInternalValue())); + base::Value::Dict valid_dict; + valid_dict.Set("protocol_str", "h2"); + valid_dict.Set("host", "valid.example.com"); + valid_dict.Set("port", 443); + valid_dict.Set("expiration", + base::NumberToString(one_day_from_now_.ToInternalValue())); alternative_service_list.Append(std::move(valid_dict)); - base::Value server_pref_dict(base::Value::Type::DICTIONARY); - server_pref_dict.SetKey("alternative_service", - std::move(alternative_service_list)); + base::Value::Dict server_pref_dict; + server_pref_dict.Set("alternative_service", + std::move(alternative_service_list)); const url::SchemeHostPort server("https", "example.com", 443); HttpServerProperties::ServerInfo server_info; @@ -1554,7 +1553,7 @@ const url::SchemeHostPort server("https", "example.com", 443); HttpServerProperties::ServerInfo server_info; EXPECT_TRUE(HttpServerPropertiesManager::ParseAlternativeServiceInfo( - server, *server_dict, &server_info)); + server, server_dict->GetDict(), &server_info)); ASSERT_TRUE(server_info.alternative_services.has_value()); AlternativeServiceInfoVector alternative_service_info_vector = @@ -3030,14 +3029,14 @@ base::JSONReader::ReadDeprecated(preferences_json); ASSERT_TRUE(preferences_dict); ASSERT_TRUE(preferences_dict->is_dict()); - const base::Value* servers_list = preferences_dict->FindListKey("servers"); + const base::Value::List* servers_list = + preferences_dict->GetDict().FindList("servers"); ASSERT_TRUE(servers_list); - ASSERT_TRUE(servers_list->is_list()); - ASSERT_EQ(servers_list->GetListDeprecated().size(), 1u); - const base::Value& server_dict = servers_list->GetListDeprecated()[0]; + ASSERT_EQ(servers_list->size(), 1u); + const base::Value& server_dict = (*servers_list)[0]; HttpServerProperties::ServerInfo server_info; EXPECT_TRUE(HttpServerPropertiesManager::ParseAlternativeServiceInfo( - server, server_dict, 
&server_info)); + server, server_dict.GetDict(), &server_info)); ASSERT_TRUE(server_info.alternative_services.has_value()); AlternativeServiceInfoVector alternative_service_info_vector_out = server_info.alternative_services.value();
diff --git a/net/ssl/client_cert_store_mac.cc b/net/ssl/client_cert_store_mac.cc index 1167c12..86f11c0 100644 --- a/net/ssl/client_cert_store_mac.cc +++ b/net/ssl/client_cert_store_mac.cc
@@ -31,7 +31,7 @@ #include "net/cert/internal/extended_key_usage.h" #include "net/cert/internal/parse_certificate.h" #include "net/cert/x509_util.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "net/cert/x509_util_mac.h" #include "net/ssl/client_cert_identity_mac.h" #include "net/ssl/ssl_platform_key_util.h"
diff --git a/net/ssl/client_cert_store_mac_unittest.cc b/net/ssl/client_cert_store_mac_unittest.cc index 93b5a7d..9b87c634 100644 --- a/net/ssl/client_cert_store_mac_unittest.cc +++ b/net/ssl/client_cert_store_mac_unittest.cc
@@ -9,7 +9,7 @@ #include "base/strings/string_number_conversions.h" #include "net/cert/internal/extended_key_usage.h" #include "net/cert/internal/parse_certificate.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "net/ssl/client_cert_identity_mac.h" #include "net/ssl/client_cert_identity_test_util.h" #include "net/ssl/client_cert_store_unittest-inl.h"
diff --git a/net/test/keychain_test_util_mac.cc b/net/test/keychain_test_util_mac.cc index f64f6f7..4a29700 100644 --- a/net/test/keychain_test_util_mac.cc +++ b/net/test/keychain_test_util_mac.cc
@@ -8,7 +8,7 @@ #include <Security/SecImportExport.h> #include "base/mac/mac_logging.h" -#include "net/cert/x509_util_ios_and_mac.h" +#include "net/cert/x509_util_apple.h" #include "third_party/boringssl/src/include/openssl/bytestring.h" #include "third_party/boringssl/src/include/openssl/ec_key.h" #include "third_party/boringssl/src/include/openssl/evp.h"
diff --git a/net/url_request/url_request.cc b/net/url_request/url_request.cc index af554caa..fcea2dd 100644 --- a/net/url_request/url_request.cc +++ b/net/url_request/url_request.cc
@@ -671,6 +671,7 @@ maybe_sent_cookies_.clear(); maybe_stored_cookies_.clear(); + has_partitioned_cookie_ = false; GURL referrer_url(referrer_); bool same_origin_for_metrics; @@ -874,6 +875,7 @@ maybe_sent_cookies_.clear(); maybe_stored_cookies_.clear(); + has_partitioned_cookie_ = false; status_ = ERR_IO_PENDING; job_->FollowDeferredRedirect(removed_headers, modified_headers); @@ -885,6 +887,7 @@ maybe_sent_cookies_.clear(); maybe_stored_cookies_.clear(); + has_partitioned_cookie_ = false; status_ = ERR_IO_PENDING; job_->SetAuth(credentials);
diff --git a/net/url_request/url_request.h b/net/url_request/url_request.h index 0aeb701..3b18dd178 100644 --- a/net/url_request/url_request.h +++ b/net/url_request/url_request.h
@@ -812,6 +812,9 @@ base::WeakPtr<URLRequest> GetWeakPtr(); + bool HasPartitionedCookie() { return has_partitioned_cookie_; } + void SetHasPartitionedCookie() { has_partitioned_cookie_ = true; } + protected: // Allow the URLRequestJob class to control the is_pending() flag. void set_is_pending(bool value) { is_pending_ = value; } @@ -1064,6 +1067,11 @@ bool send_client_certs_ = true; + // This boolean is set to true if the response has a Set-Cookie header with + // the Partitioned attribute. + // TODO(https://crbug.com/1296161): Delete this field. + bool has_partitioned_cookie_ = false; + // Idempotency of the request. Idempotency idempotency_ = DEFAULT_IDEMPOTENCY;
diff --git a/net/url_request/url_request_http_job.cc b/net/url_request/url_request_http_job.cc index e889bb0..619573f 100644 --- a/net/url_request/url_request_http_job.cc +++ b/net/url_request/url_request_http_job.cc
@@ -54,6 +54,7 @@ #include "net/cookies/cookie_store.h" #include "net/cookies/cookie_util.h" #include "net/cookies/first_party_set_metadata.h" +#include "net/cookies/parsed_cookie.h" #include "net/cookies/same_party_context.h" #include "net/filter/brotli_source_stream.h" #include "net/filter/filter_source_stream.h" @@ -929,6 +930,12 @@ num_cookie_lines_left_++; + // `cookie_partition_key_` is only non-null when partitioned cookie are + // enabled. + if (cookie_partition_key_ && ParsedCookie(cookie_string).IsPartitioned()) { + request_->SetHasPartitionedCookie(); + } + std::unique_ptr<CanonicalCookie> cookie = net::CanonicalCookie::Create( request_->url(), cookie_string, base::Time::Now(), server_time, cookie_partition_key_.value(), &returned_status);
diff --git a/net/url_request/url_request_http_job_unittest.cc b/net/url_request/url_request_http_job_unittest.cc index d6e1d23..a53909f 100644 --- a/net/url_request/url_request_http_job_unittest.cc +++ b/net/url_request/url_request_http_job_unittest.cc
@@ -2150,6 +2150,8 @@ ASSERT_TRUE(req->is_pending()); delegate.RunUntilComplete(); + ASSERT_TRUE(req->HasPartitionedCookie()); + { // Test request from the same top-level site. TestDelegate delegate; std::unique_ptr<URLRequest> req(context->CreateRequest(
diff --git a/net/url_request/url_request_unittest.cc b/net/url_request/url_request_unittest.cc index 565e2a16..62a007b1 100644 --- a/net/url_request/url_request_unittest.cc +++ b/net/url_request/url_request_unittest.cc
@@ -6235,12 +6235,11 @@ std::unique_ptr<base::Value> value( base::JSONReader::ReadDeprecated(mock_report_sender.latest_report())); ASSERT_TRUE(value); - ASSERT_TRUE(value->is_dict()); - base::DictionaryValue* report_dict; - ASSERT_TRUE(value->GetAsDictionary(&report_dict)); - std::string report_hostname; - EXPECT_TRUE(report_dict->GetString("hostname", &report_hostname)); - EXPECT_EQ(test_server_hostname, report_hostname); + base::Value::Dict* report_dict = value->GetIfDict(); + ASSERT_TRUE(report_dict); + std::string* report_hostname = report_dict->FindString("hostname"); + ASSERT_TRUE(report_hostname); + EXPECT_EQ(test_server_hostname, *report_hostname); EXPECT_EQ(isolation_info.network_isolation_key(), mock_report_sender.latest_network_isolation_key()); }
diff --git a/services/network/public/cpp/client_hints.cc b/services/network/public/cpp/client_hints.cc index c6ba0c5..da24d493 100644 --- a/services/network/public/cpp/client_hints.cc +++ b/services/network/public/cpp/client_hints.cc
@@ -55,8 +55,6 @@ "sec-ch-ua-full-version-list"}, {network::mojom::WebClientHintsType::kFullUserAgent, "sec-ch-ua-full"}, {network::mojom::WebClientHintsType::kUAWoW64, "sec-ch-ua-wow64"}, - {network::mojom::WebClientHintsType::kPartitionedCookies, - "sec-ch-partitioned-cookies"}, {network::mojom::WebClientHintsType::kSaveData, "save-data"}, }; }
diff --git a/services/network/public/cpp/cors/cors.cc b/services/network/public/cpp/cors/cors.cc index 9b1aa04..4a0b703 100644 --- a/services/network/public/cpp/cors/cors.cc +++ b/services/network/public/cpp/cors/cors.cc
@@ -382,12 +382,6 @@ "sec-ch-ua-full", "sec-ch-ua-wow64", - - // The `Sec-CH-UA-Reduced` header field is a temporary client hint, which - // will only be sent in the presence of a valid Origin Trial token. It - // was introduced to enable safely experimenting with cookies set with the - // Partitioned attribute. - "sec-ch-partitioned-cookies", }); if (!base::Contains(safe_names, lower_name))
diff --git a/services/network/public/cpp/cors/cors_unittest.cc b/services/network/public/cpp/cors/cors_unittest.cc index aa97fb7..0d3a083e 100644 --- a/services/network/public/cpp/cors/cors_unittest.cc +++ b/services/network/public/cpp/cors/cors_unittest.cc
@@ -292,7 +292,6 @@ EXPECT_TRUE(IsCorsSafelistedHeader("Sec-CH-UA-Model", "\"Model!\"")); EXPECT_TRUE(IsCorsSafelistedHeader("Sec-CH-UA-Reduced", "\"?1\"")); EXPECT_TRUE(IsCorsSafelistedHeader("Sec-CH-UA-Full", "\"?1\"")); - EXPECT_TRUE(IsCorsSafelistedHeader("Sec-CH-Partitioned-Cookies", "\"?1\"")); // TODO(mkwst): Validate that `Sec-CH-UA-*` is a structured header. // https://crbug.com/924969
diff --git a/services/network/public/cpp/url_loader_completion_status.cc b/services/network/public/cpp/url_loader_completion_status.cc index b84d6d2..225c9fd6 100644 --- a/services/network/public/cpp/url_loader_completion_status.cc +++ b/services/network/public/cpp/url_loader_completion_status.cc
@@ -43,7 +43,8 @@ blocked_by_response_reason == rhs.blocked_by_response_reason && should_report_corb_blocking == rhs.should_report_corb_blocking && proxy_server == rhs.proxy_server && - should_collapse_initiator == rhs.should_collapse_initiator; + should_collapse_initiator == rhs.should_collapse_initiator && + pervasive_payload_requested == rhs.pervasive_payload_requested; } void URLLoaderCompletionStatus::WriteIntoTrace(
diff --git a/services/network/public/cpp/url_loader_completion_status.h b/services/network/public/cpp/url_loader_completion_status.h index 9b7ee802..57c85d12 100644 --- a/services/network/public/cpp/url_loader_completion_status.h +++ b/services/network/public/cpp/url_loader_completion_status.h
@@ -104,6 +104,9 @@ // Whether the initiator of this request should be collapsed. bool should_collapse_initiator = false; + // Whether a pervasive payload is requested. + bool pervasive_payload_requested = false; + // Write a representation of this struct into a trace. void WriteIntoTrace(perfetto::TracedValue context) const; };
diff --git a/services/network/public/cpp/url_loader_completion_status_mojom_traits.cc b/services/network/public/cpp/url_loader_completion_status_mojom_traits.cc index 17d2573..97c4ba7e 100644 --- a/services/network/public/cpp/url_loader_completion_status_mojom_traits.cc +++ b/services/network/public/cpp/url_loader_completion_status_mojom_traits.cc
@@ -39,6 +39,7 @@ out->decoded_body_length = data.decoded_body_length(); out->should_report_corb_blocking = data.should_report_corb_blocking(); out->should_collapse_initiator = data.should_collapse_initiator(); + out->pervasive_payload_requested = data.pervasive_payload_requested(); return true; }
diff --git a/services/network/public/cpp/url_loader_completion_status_mojom_traits.h b/services/network/public/cpp/url_loader_completion_status_mojom_traits.h index 62a418136..d9c636f 100644 --- a/services/network/public/cpp/url_loader_completion_status_mojom_traits.h +++ b/services/network/public/cpp/url_loader_completion_status_mojom_traits.h
@@ -112,6 +112,11 @@ return status.should_collapse_initiator; } + static bool pervasive_payload_requested( + const network::URLLoaderCompletionStatus& status) { + return status.pervasive_payload_requested; + } + static bool Read(network::mojom::URLLoaderCompletionStatusDataView data, network::URLLoaderCompletionStatus* out); };
diff --git a/services/network/public/mojom/restricted_cookie_manager.mojom b/services/network/public/mojom/restricted_cookie_manager.mojom index 65dfeb8..a9f50a4d 100644 --- a/services/network/public/mojom/restricted_cookie_manager.mojom +++ b/services/network/public/mojom/restricted_cookie_manager.mojom
@@ -71,7 +71,7 @@ url.mojom.Url url, SiteForCookies site_for_cookies, url.mojom.Origin top_frame_origin, CookieManagerGetOptions options, - // TODO(crbug.com/1296161): Delete this arg when partitioned cookies + // TODO(https://crbug.com/1296161): Delete this arg when partitioned cookies // Origin Trial is over. bool partitioned_cookies_runtime_feature_enabled) => ( array<CookieWithAccessResult> cookies); @@ -104,7 +104,7 @@ SetCookieFromString(url.mojom.Url url, SiteForCookies site_for_cookies, url.mojom.Origin top_frame_origin, string cookie, - // TODO(crbug.com/1296161): Delete this arg when + // TODO(https://crbug.com/1296161): Delete this arg when // partitioned cookies Origin Trial is over. bool partitioned_cookies_runtime_feature_enabled) => (); @@ -119,7 +119,7 @@ GetCookiesString(url.mojom.Url url, SiteForCookies site_for_cookies, url.mojom.Origin top_frame_origin, - // TODO(crbug.com/1296161): Delete this arg when + // TODO(https://crbug.com/1296161): Delete this arg when // partitioned cookies Origin Trial is over. bool partitioned_cookies_runtime_feature_enabled) => ( string cookies); @@ -134,4 +134,8 @@ url.mojom.Url url, SiteForCookies site_for_cookies, url.mojom.Origin top_frame_origin) => (bool cookies_enabled); + + // TODO(https://crbug.com/1296161): Delete this method when the partitioned cookies + // Origin Trial is over. + ConvertPartitionedCookiesToUnpartitioned(url.mojom.Url url); };
diff --git a/services/network/public/mojom/url_loader_completion_status.mojom b/services/network/public/mojom/url_loader_completion_status.mojom index 1d24664..44deb2e 100644 --- a/services/network/public/mojom/url_loader_completion_status.mojom +++ b/services/network/public/mojom/url_loader_completion_status.mojom
@@ -73,4 +73,7 @@ // Whether the initiator of this request should be collapsed. bool should_collapse_initiator = false; + + // Whether a pervasive payload is requested. + bool pervasive_payload_requested = false; };
diff --git a/services/network/public/mojom/url_response_head.mojom b/services/network/public/mojom/url_response_head.mojom index 4c4cc16d..5a0cc45 100644 --- a/services/network/public/mojom/url_response_head.mojom +++ b/services/network/public/mojom/url_response_head.mojom
@@ -239,4 +239,8 @@ // algorithm. // See: https://fetch.spec.whatwg.org/#concept-http-network-fetch bool request_include_credentials = true; + + // This boolean is set to true if the response has a Set-Cookie header with the Partitioned attribute. + // TODO(https://crbug.com/1296161): Delete this field. + bool has_partitioned_cookie = false; };
diff --git a/services/network/public/mojom/web_client_hints_types.mojom b/services/network/public/mojom/web_client_hints_types.mojom index 9f356155..903c76d8 100644 --- a/services/network/public/mojom/web_client_hints_types.mojom +++ b/services/network/public/mojom/web_client_hints_types.mojom
@@ -51,9 +51,7 @@ // header contains the full user agent string. kFullUserAgent = 24, kUAWoW64 = 25, - // A client hint which, if set, signifies to the origin that the client - // supports the Partitioned cookie attribute. - kPartitionedCookies = 26, + // kPartitionedCookies = 26, Removed in M103. // Indicates the client wants to minimize data transfer if set to 'on'. kSaveData = 27,
diff --git a/services/network/restricted_cookie_manager.cc b/services/network/restricted_cookie_manager.cc index 5b0677e1..cee388b 100644 --- a/services/network/restricted_cookie_manager.cc +++ b/services/network/restricted_cookie_manager.cc
@@ -905,4 +905,14 @@ return false; } +void RestrictedCookieManager::ConvertPartitionedCookiesToUnpartitioned( + const GURL& url) { + DCHECK(base::FeatureList::IsEnabled(net::features::kPartitionedCookies)); + if (base::FeatureList::IsEnabled( + net::features::kPartitionedCookiesBypassOriginTrial)) { + return; + } + cookie_store_->ConvertPartitionedCookiesToUnpartitioned(url); +} + } // namespace network
diff --git a/services/network/restricted_cookie_manager.h b/services/network/restricted_cookie_manager.h index 3b4d5116..158a6fc 100644 --- a/services/network/restricted_cookie_manager.h +++ b/services/network/restricted_cookie_manager.h
@@ -146,6 +146,20 @@ const net::IsolationInfo& isolation_info, base::OnceCallback<void(net::FirstPartySetMetadata)> callback); + // This is a temporary method for the partitioned cookies (aka CHIPS) origin + // trial. + // + // This method allows RCM to convert any sites' partitioned cookies to + // unpartitioned. It should only exist for the duration of the CHIPS OT and + // should be deleted shortly after, since it gives untrusted processes the + // ability to convert any site's partitioned cookies to unpartitioned. + // + // Since CHIPS is still an experimental API, giving RCM this privilege should + // not be a major risk. However, before CHIPS goes live this method should be + // deleted. + // TODO(https://crbug.com/1296161): Delete this function. + void ConvertPartitionedCookiesToUnpartitioned(const GURL& url) override; + private: // The state associated with a CookieChangeListener. class Listener;
diff --git a/services/network/restricted_cookie_manager_unittest.cc b/services/network/restricted_cookie_manager_unittest.cc index aefc211..7ca5b6f 100644 --- a/services/network/restricted_cookie_manager_unittest.cc +++ b/services/network/restricted_cookie_manager_unittest.cc
@@ -2453,6 +2453,91 @@ EXPECT_TRUE(cookies[0].IsPartitioned()); } +TEST_P(PartitionedCookiesRestrictedCookieManagerTest, + ConvertPartitionedCookiesToUnpartitioned) { + const GURL kCookieURL("https://example.com"); + const GURL kTopFrameURL("https://sub.foo.com"); + const net::SiteForCookies kSiteForCookies = + net::SiteForCookies::FromUrl(kTopFrameURL); + const url::Origin kTopFrameOrigin = url::Origin::Create(kTopFrameURL); + const net::IsolationInfo kIsolationInfo = + net::IsolationInfo::CreateForInternalRequest(kTopFrameOrigin); + + service_->OverrideIsolationInfoForTesting(kIsolationInfo); + + sync_service_->SetCookieFromString( + kCookieURL, kSiteForCookies, kTopFrameOrigin, + "__Host-foo=bar; Secure; SameSite=None; Path=/; Partitioned", + /*partitioned_cookies_runtime_feature_enabled=*/true); + + auto options = mojom::CookieManagerGetOptions::New(); + options->name = ""; + options->match_type = mojom::CookieMatchType::STARTS_WITH; + auto cookies = sync_service_->GetAllForUrl( + kCookieURL, kSiteForCookies, kTopFrameOrigin, std::move(options), + /*partitioned_cookies_runtime_feature_enabled=*/true); + ASSERT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); + + service_->ConvertPartitionedCookiesToUnpartitioned(kCookieURL); + + // The partitioned cookie should now be unpartitioned. 
+ options = mojom::CookieManagerGetOptions::New(); + options->name = ""; + options->match_type = mojom::CookieMatchType::STARTS_WITH; + cookies = sync_service_->GetAllForUrl( + kCookieURL, kSiteForCookies, kTopFrameOrigin, std::move(options), + /*partitioned_cookies_runtime_feature_enabled=*/true); + ASSERT_EQ(1u, cookies.size()); + EXPECT_FALSE(cookies[0].IsPartitioned()); +} + +TEST_P(PartitionedCookiesRestrictedCookieManagerTest, + ConvertPartitionedCookiesToUnpartitioned_BypassOriginTrial) { + base::test::ScopedFeatureList feature_list; + feature_list.InitWithFeatures( + {net::features::kPartitionedCookies, + net::features::kPartitionedCookiesBypassOriginTrial}, + {}); + + const GURL kCookieURL("https://example.com"); + const GURL kTopFrameURL("https://sub.foo.com"); + const net::SiteForCookies kSiteForCookies = + net::SiteForCookies::FromUrl(kTopFrameURL); + const url::Origin kTopFrameOrigin = url::Origin::Create(kTopFrameURL); + const net::IsolationInfo kIsolationInfo = + net::IsolationInfo::CreateForInternalRequest(kTopFrameOrigin); + + service_->OverrideIsolationInfoForTesting(kIsolationInfo); + + sync_service_->SetCookieFromString( + kCookieURL, kSiteForCookies, kTopFrameOrigin, + "__Host-foo=bar; Secure; SameSite=None; Path=/; Partitioned", + /*partitioned_cookies_runtime_feature_enabled=*/false); + + auto options = mojom::CookieManagerGetOptions::New(); + options->name = ""; + options->match_type = mojom::CookieMatchType::STARTS_WITH; + auto cookies = sync_service_->GetAllForUrl( + kCookieURL, kSiteForCookies, kTopFrameOrigin, std::move(options), + /*partitioned_cookies_runtime_feature_enabled=*/true); + ASSERT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); + + service_->ConvertPartitionedCookiesToUnpartitioned(kCookieURL); + + // The partitioned cookie should remain partitioned if the origin trial bypass + // is enabled. 
+ options = mojom::CookieManagerGetOptions::New(); + options->name = ""; + options->match_type = mojom::CookieMatchType::STARTS_WITH; + cookies = sync_service_->GetAllForUrl( + kCookieURL, kSiteForCookies, kTopFrameOrigin, std::move(options), + /*partitioned_cookies_runtime_feature_enabled=*/true); + ASSERT_EQ(1u, cookies.size()); + EXPECT_TRUE(cookies[0].IsPartitioned()); +} + INSTANTIATE_TEST_SUITE_P( PartitionedCookies, PartitionedCookiesRestrictedCookieManagerTest,
diff --git a/services/network/url_loader.cc b/services/network/url_loader.cc index 4ea824c..e482fbdd 100644 --- a/services/network/url_loader.cc +++ b/services/network/url_loader.cc
@@ -791,9 +791,13 @@ if (CacheTransparencySettings::Get().PervasivePayloadsEnabled()) { auto index = CacheTransparencySettings::Get().GetIndexForURL(request.url); if (index.has_value()) { + // Remember that a pervasive payload was found so we can annotate the + // URLLoaderCompletionStatus with it later. + pervasive_payload_requested_ = true; url_request_->set_pervasive_payloads_index_for_logging(index.value()); base::UmaHistogramExactLinear("Network.CacheTransparency.URLMatched", index.value(), 101); + DVLOG(2) << "Found pervasive payload: " << request.url.spec(); } } @@ -1399,6 +1403,8 @@ response->client_address_space = private_network_access_checker_.ClientAddressSpace(); + response->has_partitioned_cookie = url_request_->HasPartitionedCookie(); + return response; } @@ -2136,6 +2142,8 @@ status.ssl_info = url_request_->ssl_info(); } + status.pervasive_payload_requested = pervasive_payload_requested_; + url_loader_client_.Get()->OnComplete(status); }
diff --git a/services/network/url_loader.h b/services/network/url_loader.h index a64d99a..d545b333b 100644 --- a/services/network/url_loader.h +++ b/services/network/url_loader.h
@@ -505,6 +505,9 @@ // Stores any CORS error encountered while processing |url_request_|. absl::optional<CorsErrorStatus> cors_error_status_; + // True if a pervasive payload is found, for logging purposes. + bool pervasive_payload_requested_ = false; + // Used when deferring sending the data to the client until mime sniffing is // finished. mojom::URLResponseHeadPtr response_;
diff --git a/services/network/url_loader_unittest.cc b/services/network/url_loader_unittest.cc index 8a77c0d..9232834 100644 --- a/services/network/url_loader_unittest.cc +++ b/services/network/url_loader_unittest.cc
@@ -7303,6 +7303,14 @@ #endif // BUILDFLAG(IS_ANDROID) +TEST_F(URLLoaderTest, HasPartitionedCookie) { + TestURLLoaderClient loader_client; + ResourceRequest request = CreateResourceRequest( + "GET", test_server()->GetURL("/set-cookie?a=b;Partitioned;")); + EXPECT_EQ(net::OK, LoadRequest(request)); + EXPECT_TRUE(client_.response_head()->has_partitioned_cookie); +} + class URLLoaderCacheTransparencyTest : public URLLoaderTest { public: void SetUp() override {
diff --git a/testing/buildbot/chromium.android.fyi.json b/testing/buildbot/chromium.android.fyi.json index befddf8c..92f7095 100644 --- a/testing/buildbot/chromium.android.fyi.json +++ b/testing/buildbot/chromium.android.fyi.json
@@ -8240,15 +8240,15 @@ { "args": [ "--additional-apk=apks/WebLayerShellSystemWebView.apk", + "--webview-apk-path=apks/SystemWebView.apk", "--test-runner-outdir", ".", + "--client-outdir", + "../../weblayer_instrumentation_test_M102/out/Release", "--implementation-outdir", ".", "--test-expectations", "../../weblayer/browser/android/javatests/skew/expectations.txt", - "--webview-apk-path=apks/SystemWebView.apk", - "--client-outdir", - "../../weblayer_instrumentation_test_M102/out/Release", "--client-version=102", "--gs-results-bucket=chromium-result-details", "--recover-devices", @@ -8274,7 +8274,7 @@ { "cipd_package": "chromium/testing/weblayer-x86", "location": "weblayer_instrumentation_test_M102", - "revision": "version:102.0.5005.74" + "revision": "version:102.0.5005.75" }, { "cipd_package": "infra/tools/luci/logdog/butler/${platform}", @@ -8750,15 +8750,15 @@ { "args": [ "--additional-apk=apks/WebLayerShellSystemWebView.apk", + "--webview-apk-path=apks/AOSP_SystemWebView.apk", "--test-runner-outdir", ".", "--client-outdir", ".", - "--test-expectations", - "../../weblayer/browser/android/javatests/skew/expectations.txt", - "--webview-apk-path=apks/AOSP_SystemWebView.apk", "--implementation-outdir", "../../weblayer_instrumentation_test_M102/out/Release", + "--test-expectations", + "../../weblayer/browser/android/javatests/skew/expectations.txt", "--impl-version=102", "--gs-results-bucket=chromium-result-details", "--recover-devices", @@ -8784,7 +8784,7 @@ { "cipd_package": "chromium/testing/weblayer-x86", "location": "weblayer_instrumentation_test_M102", - "revision": "version:102.0.5005.74" + "revision": "version:102.0.5005.75" }, { "cipd_package": "infra/tools/luci/logdog/butler/${platform}",
diff --git a/testing/buildbot/chromium.android.json b/testing/buildbot/chromium.android.json index bb6ae54..609cc124 100644 --- a/testing/buildbot/chromium.android.json +++ b/testing/buildbot/chromium.android.json
@@ -46214,15 +46214,15 @@ { "args": [ "--additional-apk=apks/WebLayerShellSystemWebView.apk", + "--webview-apk-path=apks/SystemWebView.apk", "--test-runner-outdir", ".", + "--client-outdir", + "../../weblayer_instrumentation_test_M102/out/Release", "--implementation-outdir", ".", "--test-expectations", "../../weblayer/browser/android/javatests/skew/expectations.txt", - "--webview-apk-path=apks/SystemWebView.apk", - "--client-outdir", - "../../weblayer_instrumentation_test_M102/out/Release", "--client-version=102", "--gs-results-bucket=chromium-result-details", "--recover-devices", @@ -46248,7 +46248,7 @@ { "cipd_package": "chromium/testing/weblayer-x86", "location": "weblayer_instrumentation_test_M102", - "revision": "version:102.0.5005.74" + "revision": "version:102.0.5005.75" }, { "cipd_package": "infra/tools/luci/logdog/butler/${platform}", @@ -46724,15 +46724,15 @@ { "args": [ "--additional-apk=apks/WebLayerShellSystemWebView.apk", + "--webview-apk-path=apks/AOSP_SystemWebView.apk", "--test-runner-outdir", ".", "--client-outdir", ".", - "--test-expectations", - "../../weblayer/browser/android/javatests/skew/expectations.txt", - "--webview-apk-path=apks/AOSP_SystemWebView.apk", "--implementation-outdir", "../../weblayer_instrumentation_test_M102/out/Release", + "--test-expectations", + "../../weblayer/browser/android/javatests/skew/expectations.txt", "--impl-version=102", "--gs-results-bucket=chromium-result-details", "--recover-devices", @@ -46758,7 +46758,7 @@ { "cipd_package": "chromium/testing/weblayer-x86", "location": "weblayer_instrumentation_test_M102", - "revision": "version:102.0.5005.74" + "revision": "version:102.0.5005.75" }, { "cipd_package": "infra/tools/luci/logdog/butler/${platform}", @@ -47238,15 +47238,15 @@ { "args": [ "--additional-apk=apks/ChromePublic.apk", + "--webview-apk-path=apks/SystemWebView.apk", "--test-runner-outdir", ".", + "--client-outdir", + "../../weblayer_instrumentation_test_M102/out/Release", "--implementation-outdir", 
".", "--test-expectations", "../../weblayer/browser/android/javatests/skew/expectations.txt", - "--webview-apk-path=apks/SystemWebView.apk", - "--client-outdir", - "../../weblayer_instrumentation_test_M102/out/Release", "--client-version=102", "--gs-results-bucket=chromium-result-details", "--recover-devices", @@ -47272,7 +47272,7 @@ { "cipd_package": "chromium/testing/weblayer-x86", "location": "weblayer_instrumentation_test_M102", - "revision": "version:102.0.5005.74" + "revision": "version:102.0.5005.75" }, { "cipd_package": "infra/tools/luci/logdog/butler/${platform}", @@ -47748,15 +47748,15 @@ { "args": [ "--additional-apk=apks/ChromePublic.apk", + "--webview-apk-path=apks/AOSP_SystemWebView.apk", "--test-runner-outdir", ".", "--client-outdir", ".", - "--test-expectations", - "../../weblayer/browser/android/javatests/skew/expectations.txt", - "--webview-apk-path=apks/AOSP_SystemWebView.apk", "--implementation-outdir", "../../weblayer_instrumentation_test_M102/out/Release", + "--test-expectations", + "../../weblayer/browser/android/javatests/skew/expectations.txt", "--impl-version=102", "--gs-results-bucket=chromium-result-details", "--recover-devices", @@ -47782,7 +47782,7 @@ { "cipd_package": "chromium/testing/weblayer-x86", "location": "weblayer_instrumentation_test_M102", - "revision": "version:102.0.5005.74" + "revision": "version:102.0.5005.75" }, { "cipd_package": "infra/tools/luci/logdog/butler/${platform}", @@ -48330,15 +48330,15 @@ { "args": [ "--additional-apk=apks/WebLayerShellSystemWebView.apk", + "--webview-apk-path=apks/SystemWebView.apk", "--test-runner-outdir", ".", + "--client-outdir", + "../../weblayer_instrumentation_test_M102/out/Release", "--implementation-outdir", ".", "--test-expectations", "../../weblayer/browser/android/javatests/skew/expectations.txt", - "--webview-apk-path=apks/SystemWebView.apk", - "--client-outdir", - "../../weblayer_instrumentation_test_M102/out/Release", "--client-version=102", 
"--gs-results-bucket=chromium-result-details", "--recover-devices", @@ -48364,7 +48364,7 @@ { "cipd_package": "chromium/testing/weblayer-x86", "location": "weblayer_instrumentation_test_M102", - "revision": "version:102.0.5005.74" + "revision": "version:102.0.5005.75" }, { "cipd_package": "infra/tools/luci/logdog/butler/${platform}", @@ -48840,15 +48840,15 @@ { "args": [ "--additional-apk=apks/WebLayerShellSystemWebView.apk", + "--webview-apk-path=apks/SystemWebView.apk", "--test-runner-outdir", ".", "--client-outdir", ".", - "--test-expectations", - "../../weblayer/browser/android/javatests/skew/expectations.txt", - "--webview-apk-path=apks/SystemWebView.apk", "--implementation-outdir", "../../weblayer_instrumentation_test_M102/out/Release", + "--test-expectations", + "../../weblayer/browser/android/javatests/skew/expectations.txt", "--impl-version=102", "--gs-results-bucket=chromium-result-details", "--recover-devices", @@ -48874,7 +48874,7 @@ { "cipd_package": "chromium/testing/weblayer-x86", "location": "weblayer_instrumentation_test_M102", - "revision": "version:102.0.5005.74" + "revision": "version:102.0.5005.75" }, { "cipd_package": "infra/tools/luci/logdog/butler/${platform}", @@ -49422,15 +49422,15 @@ { "args": [ "--additional-apk=apks/WebLayerShellSystemWebView.apk", + "--webview-apk-path=apks/SystemWebView.apk", "--test-runner-outdir", ".", + "--client-outdir", + "../../weblayer_instrumentation_test_M102/out/Release", "--implementation-outdir", ".", "--test-expectations", "../../weblayer/browser/android/javatests/skew/expectations.txt", - "--webview-apk-path=apks/SystemWebView.apk", - "--client-outdir", - "../../weblayer_instrumentation_test_M102/out/Release", "--client-version=102", "--gs-results-bucket=chromium-result-details", "--recover-devices", @@ -49456,7 +49456,7 @@ { "cipd_package": "chromium/testing/weblayer-x86", "location": "weblayer_instrumentation_test_M102", - "revision": "version:102.0.5005.74" + "revision": "version:102.0.5005.75" }, { 
"cipd_package": "infra/tools/luci/logdog/butler/${platform}", @@ -49932,15 +49932,15 @@ { "args": [ "--additional-apk=apks/WebLayerShellSystemWebView.apk", + "--webview-apk-path=apks/SystemWebView.apk", "--test-runner-outdir", ".", "--client-outdir", ".", - "--test-expectations", - "../../weblayer/browser/android/javatests/skew/expectations.txt", - "--webview-apk-path=apks/SystemWebView.apk", "--implementation-outdir", "../../weblayer_instrumentation_test_M102/out/Release", + "--test-expectations", + "../../weblayer/browser/android/javatests/skew/expectations.txt", "--impl-version=102", "--gs-results-bucket=chromium-result-details", "--recover-devices", @@ -49966,7 +49966,7 @@ { "cipd_package": "chromium/testing/weblayer-x86", "location": "weblayer_instrumentation_test_M102", - "revision": "version:102.0.5005.74" + "revision": "version:102.0.5005.75" }, { "cipd_package": "infra/tools/luci/logdog/butler/${platform}",
diff --git a/testing/buildbot/variants.pyl b/testing/buildbot/variants.pyl index 186ca758..7c8df96 100644 --- a/testing/buildbot/variants.pyl +++ b/testing/buildbot/variants.pyl
@@ -486,16 +486,16 @@ }, 'WEBLAYER_10_AND_M_IMPL_SKEW_TESTS_NTH_MINUS_ONE_MILESTONE': { 'args': [ + '--webview-apk-path=apks/AOSP_SystemWebView.apk', '--test-runner-outdir', '.', '--client-outdir', '.', - '--test-expectations', - '../../weblayer/browser/android/javatests/skew/expectations.txt', - '--webview-apk-path=apks/AOSP_SystemWebView.apk', '--implementation-outdir', '../../weblayer_instrumentation_test_M102/out/Release', - '--impl-version=102' + '--test-expectations', + '../../weblayer/browser/android/javatests/skew/expectations.txt', + '--impl-version=102', ], 'identifier': 'with_impl_from_102', 'swarming': { @@ -503,10 +503,10 @@ { 'cipd_package': 'chromium/testing/weblayer-x86', 'location': 'weblayer_instrumentation_test_M102', - 'revision': 'version:102.0.5005.74' + 'revision': 'version:102.0.5005.75', } - ] - } + ], + }, }, 'WEBLAYER_10_AND_M_IMPL_SKEW_TESTS_NTH_MINUS_TWO_MILESTONE': { 'args': [ @@ -630,16 +630,16 @@ }, 'WEBLAYER_IMPL_SKEW_TESTS_NTH_MINUS_ONE_MILESTONE': { 'args': [ + '--webview-apk-path=apks/SystemWebView.apk', '--test-runner-outdir', '.', '--client-outdir', '.', - '--test-expectations', - '../../weblayer/browser/android/javatests/skew/expectations.txt', - '--webview-apk-path=apks/SystemWebView.apk', '--implementation-outdir', '../../weblayer_instrumentation_test_M102/out/Release', - '--impl-version=102' + '--test-expectations', + '../../weblayer/browser/android/javatests/skew/expectations.txt', + '--impl-version=102', ], 'identifier': 'with_impl_from_102', 'swarming': { @@ -647,10 +647,10 @@ { 'cipd_package': 'chromium/testing/weblayer-x86', 'location': 'weblayer_instrumentation_test_M102', - 'revision': 'version:102.0.5005.74' + 'revision': 'version:102.0.5005.75', } - ] - } + ], + }, }, 'WEBLAYER_IMPL_SKEW_TESTS_NTH_MINUS_TWO_MILESTONE': { 'args': [ @@ -774,16 +774,16 @@ }, 'WEBLAYER_CLIENT_SKEW_TESTS_NTH_MINUS_ONE_MILESTONE': { 'args': [ + '--webview-apk-path=apks/SystemWebView.apk', '--test-runner-outdir', '.', + '--client-outdir', 
+ '../../weblayer_instrumentation_test_M102/out/Release', '--implementation-outdir', '.', '--test-expectations', '../../weblayer/browser/android/javatests/skew/expectations.txt', - '--webview-apk-path=apks/SystemWebView.apk', - '--client-outdir', - '../../weblayer_instrumentation_test_M102/out/Release', - '--client-version=102' + '--client-version=102', ], 'identifier': 'with_client_from_102', 'swarming': { @@ -791,10 +791,10 @@ { 'cipd_package': 'chromium/testing/weblayer-x86', 'location': 'weblayer_instrumentation_test_M102', - 'revision': 'version:102.0.5005.74' + 'revision': 'version:102.0.5005.75', } - ] - } + ], + }, }, 'WEBLAYER_CLIENT_SKEW_TESTS_NTH_MINUS_TWO_MILESTONE': { 'args': [
diff --git a/testing/variations/fieldtrial_testing_config.json b/testing/variations/fieldtrial_testing_config.json index e8897a91..df3d2d18 100644 --- a/testing/variations/fieldtrial_testing_config.json +++ b/testing/variations/fieldtrial_testing_config.json
@@ -8757,6 +8757,33 @@ ] } ], + "WebApkServiceWorkerRemovalStudy": [ + { + "platforms": [ + "android" + ], + "experiments": [ + { + "name": "Enabled_SkipAllServiceWorkerChecks_20220513", + "enable_features": [ + "SkipServiceWorkerCheckAll" + ], + "disable_features": [ + "SkipServiceWorkerCheckInstallOnly" + ] + }, + { + "name": "Enabled_SkipServiceWorkerChecksInstallOnly", + "enable_features": [ + "SkipServiceWorkerCheckInstallOnly" + ], + "disable_features": [ + "SkipServiceWorkerCheckAll" + ] + } + ] + } + ], "WebFeedsMVP": [ { "platforms": [
diff --git a/third_party/blink/common/client_hints/client_hints.cc b/third_party/blink/common/client_hints/client_hints.cc index 373c3d9..fb72470 100644 --- a/third_party/blink/common/client_hints/client_hints.cc +++ b/third_party/blink/common/client_hints/client_hints.cc
@@ -76,8 +76,6 @@ mojom::PermissionsPolicyFeature::kClientHintUAFull}, {network::mojom::WebClientHintsType::kUAWoW64, mojom::PermissionsPolicyFeature::kClientHintUAWoW64}, - {network::mojom::WebClientHintsType::kPartitionedCookies, - mojom::PermissionsPolicyFeature::kClientHintPartitionedCookies}, {network::mojom::WebClientHintsType::kSaveData, mojom::PermissionsPolicyFeature::kClientHintSaveData}, };
diff --git a/third_party/blink/common/client_hints/client_hints_unittest.cc b/third_party/blink/common/client_hints/client_hints_unittest.cc index 046a865a..171a887 100644 --- a/third_party/blink/common/client_hints/client_hints_unittest.cc +++ b/third_party/blink/common/client_hints/client_hints_unittest.cc
@@ -33,8 +33,7 @@ "sec-ch-prefers-color-scheme", "sec-ch-ua-bitness", "sec-ch-ua-reduced", "sec-ch-viewport-height", "sec-ch-device-memory", "sec-ch-dpr", "sec-ch-width", "sec-ch-viewport-width", - "sec-ch-ua-full-version-list", "sec-ch-ua-full", "sec-ch-ua-wow64", - "sec-ch-partitioned-cookies")); + "sec-ch-ua-full-version-list", "sec-ch-ua-full", "sec-ch-ua-wow64")); } // Checks that the removed header list includes legacy headers but not the @@ -54,6 +53,6 @@ "sec-ch-ua-bitness", "sec-ch-ua-reduced", "sec-ch-viewport-height", "sec-ch-device-memory", "sec-ch-dpr", "sec-ch-width", "sec-ch-viewport-width", "sec-ch-ua-full-version-list", - "sec-ch-ua-full", "sec-ch-ua-wow64", "sec-ch-partitioned-cookies")); + "sec-ch-ua-full", "sec-ch-ua-wow64")); } } // namespace blink
diff --git a/third_party/blink/common/client_hints/enabled_client_hints.cc b/third_party/blink/common/client_hints/enabled_client_hints.cc index 075e1382..b8bd0e4 100644 --- a/third_party/blink/common/client_hints/enabled_client_hints.cc +++ b/third_party/blink/common/client_hints/enabled_client_hints.cc
@@ -6,7 +6,6 @@ #include "base/feature_list.h" #include "base/time/time.h" -#include "net/base/features.h" #include "net/http/http_response_headers.h" #include "services/network/public/cpp/client_hints.h" #include "third_party/blink/public/common/features.h" @@ -85,10 +84,6 @@ features::kClientHintsViewportWidth_DEPRECATED)) return true; break; - case WebClientHintsType::kPartitionedCookies: - if (!base::FeatureList::IsEnabled(net::features::kPartitionedCookies)) - return true; - break; case WebClientHintsType::kSaveData: if (!base::FeatureList::IsEnabled(features::kClientHintsSaveData)) return true; @@ -160,10 +155,6 @@ enabled = IsOriginTrialEnabled(url, third_party_url, response_headers, "SendFullUserAgentAfterReduction"); } - if (enabled && type == WebClientHintsType::kPartitionedCookies) { - enabled = IsOriginTrialEnabled(url, third_party_url, response_headers, - "PartitionedCookies"); - } SetIsEnabled(type, enabled); }
diff --git a/third_party/blink/common/client_hints/enabled_client_hints_unittest.cc b/third_party/blink/common/client_hints/enabled_client_hints_unittest.cc index 26101764..ef0438f2 100644 --- a/third_party/blink/common/client_hints/enabled_client_hints_unittest.cc +++ b/third_party/blink/common/client_hints/enabled_client_hints_unittest.cc
@@ -7,7 +7,6 @@ #include "absl/types/optional.h" #include "base/memory/scoped_refptr.h" #include "base/test/scoped_feature_list.h" -#include "net/base/features.h" #include "net/http/http_response_headers.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -322,128 +321,4 @@ WebClientHintsType::kUAFullVersionList)); } -// TODO(crbug.com/1296161): Delete this when partitioned cookies Origin Trial is -// over. -class PartitionedCookiesEnabledClientHintsTest - : public testing::TestWithParam<bool> { - protected: - PartitionedCookiesEnabledClientHintsTest() - : response_headers_(base::MakeRefCounted<net::HttpResponseHeaders>("")) { - TrialTokenValidator::SetOriginTrialPolicyGetter( - base::BindRepeating([](OriginTrialPolicy* policy) { return policy; }, - base::Unretained(&policy_))); - policy_.SetPublicKeys({kTestPublicKey}); - } - - void SetUp() override { - std::vector<base::Feature> enabled_features = { - blink::features::kUserAgentClientHint, - blink::features::kUserAgentClientHintFullVersionList}; - std::vector<base::Feature> disabled_features = { - blink::features::kPrefersColorSchemeClientHintHeader}; - if (PartitionedCookiesEnabled()) { - enabled_features.push_back(net::features::kPartitionedCookies); - } else { - disabled_features.push_back(net::features::kPartitionedCookies); - } - - scoped_feature_list_.InitWithFeatures(enabled_features, disabled_features); - testing::TestWithParam<bool>::SetUp(); - } - - bool PartitionedCookiesEnabled() { return GetParam(); } - - void VerifyClientHintEnabledWithOriginTrialToken( - const std::string& token, - const GURL* third_party_url, - const WebClientHintsType client_hint_type, - bool expected_client_hint_enabled) { - VerifyClientHintEnabledWithOriginTrialTokenInner( - response_headers_.get(), token, third_party_url, client_hint_type, - expected_client_hint_enabled); - } - - private: - base::test::ScopedFeatureList scoped_feature_list_; - 
scoped_refptr<net::HttpResponseHeaders> response_headers_; - TestOriginTrialPolicy policy_; -}; - -INSTANTIATE_TEST_SUITE_P(/* no label */, - PartitionedCookiesEnabledClientHintsTest, - testing::Bool()); - -TEST_P(PartitionedCookiesEnabledClientHintsTest, - EnabledPartitionedCookiesClientHintWithValidOriginTrialToken) { - // Generated by running (in tools/origin_trials): - // generate_token.py https://127.0.0.1:44444 PartitionedCookies - // --expire-timestamp=2000000000 - // - // The Origin Trial token expires in 2033. Generate a new token by then, or - // find a better way to re-generate a test trial token. - static constexpr char kValidOriginTrialToken[] = - "A4s/" - "iPKfhEfgqQIIuz4zLuCpONpXOuYyJFBhBx1MfgS1aNhFujyhsg4lkfTRfjzQCI3aUbMwtNm2" - "5elLTR4UIgAAAABceyJvcmlnaW4iOiAiaHR0cHM6Ly8xMjcuMC4wLjE6NDQ0NDQiLCAiZmVh" - "dHVyZSI6ICJQYXJ0aXRpb25lZENvb2tpZXMiLCAiZXhwaXJ5IjogMjAwMDAwMDAwMH0="; - - VerifyClientHintEnabledWithOriginTrialToken( - kValidOriginTrialToken, - /*third_party_url=*/nullptr, WebClientHintsType::kPartitionedCookies, - /*expected_client_hint_enabled=*/PartitionedCookiesEnabled()); -} - -TEST_P(PartitionedCookiesEnabledClientHintsTest, - EnabledPartitionedCookiesClientHintWithInvalidOriginTrialToken) { - // Changed the first character of the token in the last test. 
- static constexpr char kValidOriginTrialToken[] = - "B4s/" - "iPKfhEfgqQIIuz4zLuCpONpXOuYyJFBhBx1MfgS1aNhFujyhsg4lkfTRfjzQCI3aUbMwtNm2" - "5elLTR4UIgAAAABceyJvcmlnaW4iOiAiaHR0cHM6Ly8xMjcuMC4wLjE6NDQ0NDQiLCAiZmVh" - "dHVyZSI6ICJQYXJ0aXRpb25lZENvb2tpZXMiLCAiZXhwaXJ5IjogMjAwMDAwMDAwMH0="; - - VerifyClientHintEnabledWithOriginTrialToken( - kValidOriginTrialToken, - /*third_party_url=*/nullptr, WebClientHintsType::kPartitionedCookies, - /*expected_client_hint_enabled=*/false); -} - -TEST_P(PartitionedCookiesEnabledClientHintsTest, - EnabledPartitionedCookiesClientHintWithValidThirdPartyOriginTrialToken) { - // Generated by running (in tools/origin_trials): - // generate_token.py https://127.0.0.1:44445 PartitionedCookies - // --expire-timestamp=2000000000 --is-third-party - // - // The Origin Trial token expires in 2033. Generate a new token by then, or - // find a better way to re-generate a test trial token. - static constexpr char kValidThirdPartyOriginTrialToken[] = - "A2VMEbGkZuIokMW5yBD0YFxwr8cyNw8iqteLIH7mv2bbKdoyIe4IFNC9G/" - "Fk7sfN5gcwcwtSJEYsMp2e6ol6cQgAAAByeyJvcmlnaW4iOiAiaHR0cHM6Ly8xMjcuMC4wLj" - "E6NDQ0NDUiLCAiZmVhdHVyZSI6ICJQYXJ0aXRpb25lZENvb2tpZXMiLCAiZXhwaXJ5IjogMj" - "AwMDAwMDAwMCwgImlzVGhpcmRQYXJ0eSI6IHRydWV9"; - - const GURL third_party_url = GURL(kThirdPartyOriginUrl); - VerifyClientHintEnabledWithOriginTrialToken( - kValidThirdPartyOriginTrialToken, &third_party_url, - WebClientHintsType::kPartitionedCookies, - /*expected_client_hint_enabled=*/PartitionedCookiesEnabled()); -} - -TEST_P( - PartitionedCookiesEnabledClientHintsTest, - EnabledPartitionedCookiesClientHintWithInvalidThirdPartyOriginTrialToken) { - // Changed the first character of the token in the last test. 
- static constexpr char kValidThirdPartyOriginTrialToken[] = - "B2VMEbGkZuIokMW5yBD0YFxwr8cyNw8iqteLIH7mv2bbKdoyIe4IFNC9G/" - "Fk7sfN5gcwcwtSJEYsMp2e6ol6cQgAAAByeyJvcmlnaW4iOiAiaHR0cHM6Ly8xMjcuMC4wLj" - "E6NDQ0NDUiLCAiZmVhdHVyZSI6ICJQYXJ0aXRpb25lZENvb2tpZXMiLCAiZXhwaXJ5IjogMj" - "AwMDAwMDAwMCwgImlzVGhpcmRQYXJ0eSI6IHRydWV9"; - - const GURL third_party_url = GURL(kThirdPartyOriginUrl); - VerifyClientHintEnabledWithOriginTrialToken( - kValidThirdPartyOriginTrialToken, &third_party_url, - WebClientHintsType::kPartitionedCookies, - /*expected_client_hint_enabled=*/false); -} - } // namespace blink
diff --git a/third_party/blink/common/features.cc b/third_party/blink/common/features.cc index a4d753a..c9a9913 100644 --- a/third_party/blink/common/features.cc +++ b/third_party/blink/common/features.cc
@@ -1357,9 +1357,6 @@ const base::Feature kElementSuperRareData{"ElementSuperRareData", base::FEATURE_DISABLED_BY_DEFAULT}; -const base::Feature kClientHintsPartitionedCookies{ - "ClientHintsPartitionedCookies", base::FEATURE_DISABLED_BY_DEFAULT}; - const base::Feature kDurableClientHintsCache{"DurableClientHintsCache", base::FEATURE_ENABLED_BY_DEFAULT};
diff --git a/third_party/blink/public/common/features.h b/third_party/blink/public/common/features.h index 98f73c6..a5ee71e 100644 --- a/third_party/blink/public/common/features.h +++ b/third_party/blink/public/common/features.h
@@ -659,6 +659,9 @@ BLINK_COMMON_EXPORT extern const base::Feature kElementSuperRareData; BLINK_COMMON_EXPORT extern const base::Feature kClientHintsPartitionedCookies; +BLINK_COMMON_EXPORT extern const base::Feature kScaleTileMemoryLimit; +BLINK_COMMON_EXPORT +extern const base::FeatureParam<double> kScaleTileMemoryLimitFactor; // If enabled, the client hints cache will be loaded on browser restarts. BLINK_COMMON_EXPORT extern const base::Feature kDurableClientHintsCache;
diff --git a/third_party/blink/public/devtools_protocol/browser_protocol.pdl b/third_party/blink/public/devtools_protocol/browser_protocol.pdl index e889a0e..2f1befd 100644 --- a/third_party/blink/public/devtools_protocol/browser_protocol.pdl +++ b/third_party/blink/public/devtools_protocol/browser_protocol.pdl
@@ -6999,7 +6999,6 @@ ch-device-memory ch-downlink ch-ect - ch-partitioned-cookies ch-prefers-color-scheme ch-rtt ch-save-data
diff --git a/third_party/blink/public/mojom/permissions_policy/permissions_policy_feature.mojom b/third_party/blink/public/mojom/permissions_policy/permissions_policy_feature.mojom index c9232f4..b67a678 100644 --- a/third_party/blink/public/mojom/permissions_policy/permissions_policy_feature.mojom +++ b/third_party/blink/public/mojom/permissions_policy/permissions_policy_feature.mojom
@@ -193,21 +193,6 @@ // Client Hint for Sec-CH-UA-WoW64. kClientHintUAWoW64 = 95, - // Client hint for indicating that the client supports the Partitioned cookie - // attribute. - // - // The `Sec-CH-Partitioned-Cookies` header field is a temporary client hint, - // which will only be sent in the presence of a valid Origin Trial token. It - // was introduced to enable safely experimenting with cookies set with the - // Partitioned attribute. - // - // See https://chromestatus.com/feature/5179189105786880 for the Partitioned - // cookie attribute (a.k.a. Cookies Having Independent Partitioned State, - // CHIPS) Chrome feature and the explainer at https://github.com/WICG/CHIPS - // for details about the design of the Partitioned attribute and partitioned - // cookies. - kClientHintPartitionedCookies = 96, - // "browsing-topics" permissions policy that controls the use of Topics API. // https://github.com/jkarlin/topics kBrowsingTopics = 97,
diff --git a/third_party/blink/public/mojom/use_counter/metrics/web_feature.mojom b/third_party/blink/public/mojom/use_counter/metrics/web_feature.mojom index 464dc967c..59d310b 100644 --- a/third_party/blink/public/mojom/use_counter/metrics/web_feature.mojom +++ b/third_party/blink/public/mojom/use_counter/metrics/web_feature.mojom
@@ -3478,7 +3478,7 @@ kV8UDPSocket_RemotePort_AttributeGetter = 4157, kV8UDPSocket_Writable_AttributeGetter = 4158, kAbortSignalTimeout = 4159, - kClientHintsPartitionedCookies = 4160, + kOBSOLETE_ClientHintsPartitionedCookies = 4160, kV8Document_Prerendering_AttributeGetter = 4161, kV8Document_Onprerenderingchange_AttributeGetter = 4162, kV8Document_Onprerenderingchange_AttributeSetter = 4163, @@ -3556,6 +3556,11 @@ kV8PaymentInstruments_Set_Method = 4235, kPerformanceMeasureFindExistingName = 4236, kFlexboxNewAbsPos = 4237, + kScriptSchedulingType_Defer = 4238, + kScriptSchedulingType_ParserBlocking = 4239, + kScriptSchedulingType_ParserBlockingInline = 4240, + kScriptSchedulingType_InOrder = 4241, + kScriptSchedulingType_Async = 4242, // Add new features immediately above this line. Don't change assigned // numbers of any item, and don't reuse removed slots.
diff --git a/third_party/blink/public/platform/web_url_loader_client.h b/third_party/blink/public/platform/web_url_loader_client.h index 9d430c5..b111c2f5 100644 --- a/third_party/blink/public/platform/web_url_loader_client.h +++ b/third_party/blink/public/platform/web_url_loader_client.h
@@ -110,11 +110,13 @@ // will be generated in devtools console if this flag is set to true. // TODO(crbug.com/798625): use different callback for subresources // with responses blocked due to document protection. - virtual void DidFinishLoading(base::TimeTicks finish_time, - int64_t total_encoded_data_length, - int64_t total_encoded_body_length, - int64_t total_decoded_body_length, - bool should_report_corb_blocking) {} + virtual void DidFinishLoading( + base::TimeTicks finish_time, + int64_t total_encoded_data_length, + int64_t total_encoded_body_length, + int64_t total_decoded_body_length, + bool should_report_corb_blocking, + absl::optional<bool> pervasive_payload_requested = absl::nullopt) {} // Called when the load completes with an error. // |finish_time| indicating the time in which the response failed.
diff --git a/third_party/blink/public/platform/web_url_response.h b/third_party/blink/public/platform/web_url_response.h index bce15a98..1b0e7314 100644 --- a/third_party/blink/public/platform/web_url_response.h +++ b/third_party/blink/public/platform/web_url_response.h
@@ -284,6 +284,9 @@ BLINK_PLATFORM_EXPORT void SetWasFetchedViaCache(bool); BLINK_PLATFORM_EXPORT void SetArrivalTimeAtRenderer(base::TimeTicks arrival); + BLINK_PLATFORM_EXPORT void SetHasPartitionedCookie( + bool has_partitioned_cookie); + #if INSIDE_BLINK protected: // Permit subclasses to set arbitrary ResourceResponse pointer as
diff --git a/third_party/blink/renderer/core/dom/document.cc b/third_party/blink/renderer/core/dom/document.cc index 8e3b874..37035de9 100644 --- a/third_party/blink/renderer/core/dom/document.cc +++ b/third_party/blink/renderer/core/dom/document.cc
@@ -8416,6 +8416,13 @@ Document::PendingJavascriptUrl::~PendingJavascriptUrl() = default; +void Document::CheckPartitionedCookiesOriginTrial( + const ResourceResponse& response) { + // if (!cookie_jar_) + // return; + cookie_jar_->CheckPartitionedCookiesOriginTrial(response); +} + template class CORE_TEMPLATE_EXPORT Supplement<Document>; } // namespace blink
diff --git a/third_party/blink/renderer/core/dom/document.h b/third_party/blink/renderer/core/dom/document.h index f171d4b..c2503dfd 100644 --- a/third_party/blink/renderer/core/dom/document.h +++ b/third_party/blink/renderer/core/dom/document.h
@@ -205,6 +205,7 @@ class RenderBlockingResourceManager; class ResizeObserver; class ResourceFetcher; +class ResourceResponse; class RootScrollerController; class SVGDocumentExtensions; class SVGUseElement; @@ -1874,6 +1875,9 @@ void WriteIntoTrace(perfetto::TracedValue ctx) const; + // TODO(https://crbug.com/1296161): Delete this function. + void CheckPartitionedCookiesOriginTrial(const ResourceResponse& response); + protected: void ClearXMLVersion() { xml_version_ = String(); }
diff --git a/third_party/blink/renderer/core/html/html_script_element.cc b/third_party/blink/renderer/core/html/html_script_element.cc index 679fa6d..4f7e637 100644 --- a/third_party/blink/renderer/core/html/html_script_element.cc +++ b/third_party/blink/renderer/core/html/html_script_element.cc
@@ -120,9 +120,8 @@ Node::InsertionNotificationRequest HTMLScriptElement::InsertedInto( ContainerNode& insertion_point) { if (insertion_point.isConnected() && HasSourceAttribute() && - ScriptLoader::GetScriptTypeAtPrepare( - TypeAttributeValue(), LanguageAttributeValue(), - ScriptLoader::kDisallowLegacyTypeInTypeAttribute) == + ScriptLoader::GetScriptTypeAtPrepare(TypeAttributeValue(), + LanguageAttributeValue()) == ScriptLoader::ScriptTypeAtPrepare::kInvalid) { UseCounter::Count(GetDocument(), WebFeature::kScriptElementWithInvalidTypeHasSrc);
diff --git a/third_party/blink/renderer/core/html/parser/html_preload_scanner.cc b/third_party/blink/renderer/core/html/parser/html_preload_scanner.cc index 3545884a..7402fac5 100644 --- a/third_party/blink/renderer/core/html/parser/html_preload_scanner.cc +++ b/third_party/blink/renderer/core/html/parser/html_preload_scanner.cc
@@ -686,9 +686,8 @@ return false; if (Match(tag_impl_, html_names::kScriptTag)) { ScriptLoader::ScriptTypeAtPrepare script_type = - ScriptLoader::GetScriptTypeAtPrepare( - type_attribute_value_, language_attribute_value_, - ScriptLoader::kDisallowLegacyTypeInTypeAttribute); + ScriptLoader::GetScriptTypeAtPrepare(type_attribute_value_, + language_attribute_value_); switch (script_type) { case ScriptLoader::ScriptTypeAtPrepare::kInvalid: return false; @@ -1000,8 +999,7 @@ if (type_attribute && ScriptLoader::GetScriptTypeAtPrepare( type_attribute->Value(), - /*language_attribute_value=*/g_empty_atom, - ScriptLoader::kDisallowLegacyTypeInTypeAttribute) == + /*language_attribute_value=*/g_empty_atom) == ScriptLoader::ScriptTypeAtPrepare::kWebBundle) { in_script_web_bundle_ = true; }
diff --git a/third_party/blink/renderer/core/inspector/inspector_page_agent.cc b/third_party/blink/renderer/core/inspector/inspector_page_agent.cc index f57308c9..ef3e2cb 100644 --- a/third_party/blink/renderer/core/inspector/inspector_page_agent.cc +++ b/third_party/blink/renderer/core/inspector/inspector_page_agent.cc
@@ -919,53 +919,43 @@ return Decimal::FromString(a) < Decimal::FromString(b); }); + // Throughout this method, + // `ExecuteScriptPolicy::kExecuteScriptWhenScriptsDisabled` is used because + // `inspector-protocol/page/add-script-to-evaluate-on-load-disabled-js.js` + // requires that the scripts here should be evaluated on pages with scripting + // disabled. + for (const WTF::String& key : keys) { - const String source = scripts_to_evaluate_on_load_.Get(key); - const String world_name = worlds_to_evaluate_on_load_.Get(key); - const bool include_command_line_api = - include_command_line_api_for_scripts_to_evaluate_on_load_.Get(key); auto* window = frame->DomWindow(); + v8::HandleScope handle_scope(window->GetIsolate()); + + ScriptState* script_state = nullptr; + const String world_name = worlds_to_evaluate_on_load_.Get(key); if (world_name.IsEmpty()) { - if (include_command_line_api) { - v8::HandleScope handle_scope(window->GetIsolate()); - ScriptState* script_state = - ToScriptStateForMainWorld(window->GetFrame()); - auto scope = v8_session_->initializeCommandLineAPIScope( - v8_inspector::V8ContextInfo::executionContextId( - script_state->GetContext())); - DCHECK(scope); - ClassicScript::CreateUnspecifiedScript(source)->RunScript( - window, ExecuteScriptPolicy::kExecuteScriptWhenScriptsDisabled); - } else { - ClassicScript::CreateUnspecifiedScript(source)->RunScript( - window, ExecuteScriptPolicy::kExecuteScriptWhenScriptsDisabled); - } - continue; - } - - scoped_refptr<DOMWrapperWorld> world = EnsureDOMWrapperWorld( - frame, world_name, true /* grant_universal_access */); - if (!world) - continue; - - // Note: An error event in an isolated world will never be dispatched to - // a foreign world. 
- v8::HandleScope handle_scope(V8PerIsolateData::MainThreadIsolate()); - if (include_command_line_api) { - ScriptState* script_state = ToScriptState( + script_state = ToScriptStateForMainWorld(window->GetFrame()); + } else if (scoped_refptr<DOMWrapperWorld> world = EnsureDOMWrapperWorld( + frame, world_name, true /* grant_universal_access */)) { + script_state = ToScriptState( window->GetFrame(), *DOMWrapperWorld::EnsureIsolatedWorld(ToIsolate(window->GetFrame()), world->GetWorldId())); - auto scope = v8_session_->initializeCommandLineAPIScope( + } + if (!script_state) + continue; + + std::unique_ptr<v8_inspector::V8InspectorSession::CommandLineAPIScope> + scope; + if (include_command_line_api_for_scripts_to_evaluate_on_load_.Get(key)) { + scope = v8_session_->initializeCommandLineAPIScope( v8_inspector::V8ContextInfo::executionContextId( script_state->GetContext())); DCHECK(scope); - ClassicScript::CreateUnspecifiedScript(source) - ->RunScriptInIsolatedWorldAndReturnValue(window, world->GetWorldId()); - } else { - ClassicScript::CreateUnspecifiedScript(source) - ->RunScriptInIsolatedWorldAndReturnValue(window, world->GetWorldId()); } + ClassicScript::CreateUnspecifiedScript( + scripts_to_evaluate_on_load_.Get(key)) + ->RunScriptOnScriptState( + script_state, + ExecuteScriptPolicy::kExecuteScriptWhenScriptsDisabled); } if (!script_to_evaluate_on_load_once_.IsEmpty()) {
diff --git a/third_party/blink/renderer/core/loader/base_fetch_context.cc b/third_party/blink/renderer/core/loader/base_fetch_context.cc index 94a6c6c..9e44eb1 100644 --- a/third_party/blink/renderer/core/loader/base_fetch_context.cc +++ b/third_party/blink/renderer/core/loader/base_fetch_context.cc
@@ -480,17 +480,6 @@ prefers_color_scheme.value()); } - if (ShouldSendClientHint( - ClientHintsMode::kStandard, policy, resource_origin, is_1p_origin, - network::mojom::blink::WebClientHintsType::kPartitionedCookies, - hints_preferences)) { - request.SetHttpHeaderField( - network::GetClientHintToNameMap() - .at(network::mojom::blink::WebClientHintsType::kPartitionedCookies) - .c_str(), - SerializeBoolHeader(true)); - } - if (ShouldSendClientHint(ClientHintsMode::kStandard, policy, resource_origin, is_1p_origin, network::mojom::blink::WebClientHintsType::kSaveData, @@ -708,14 +697,12 @@ base::FeatureList::IsEnabled(features::kAllowClientHintsToThirdParty)) { origin_ok = true; } else { - // For subresource requests, if the parent frame has Sec-CH-UA-Reduced, - // Sec-CH-UA-Full, or Sec-CH-Partitioned-Cookies, then send the hint in the - // fetch request, regardless of the permissions policy. + // For subresource requests, if the parent frame has Sec-CH-UA-Reduced or + // Sec-CH-UA-Full then send the hint in the fetch request, regardless of the + // permissions policy. origin_ok = type == network::mojom::blink::WebClientHintsType::kUAReduced || type == network::mojom::blink::WebClientHintsType::kFullUserAgent || - type == - network::mojom::blink::WebClientHintsType::kPartitionedCookies || (policy && policy->IsFeatureEnabledForOrigin( GetClientHintToPolicyFeatureMap().at(type), resource_origin));
diff --git a/third_party/blink/renderer/core/loader/cookie_jar.cc b/third_party/blink/renderer/core/loader/cookie_jar.cc index 05c4d24..ddc73a1 100644 --- a/third_party/blink/renderer/core/loader/cookie_jar.cc +++ b/third_party/blink/renderer/core/loader/cookie_jar.cc
@@ -4,9 +4,15 @@ #include "third_party/blink/renderer/core/loader/cookie_jar.h" +#include "base/feature_list.h" #include "base/metrics/histogram_functions.h" #include "base/strings/strcat.h" +#include "net/base/features.h" +#include "net/cookies/parsed_cookie.h" #include "third_party/blink/public/common/browser_interface_broker_proxy.h" +#include "third_party/blink/public/common/origin_trials/trial_token.h" +#include "third_party/blink/public/common/origin_trials/trial_token_result.h" +#include "third_party/blink/public/common/origin_trials/trial_token_validator.h" #include "third_party/blink/renderer/core/dom/document.h" #include "third_party/blink/renderer/core/execution_context/execution_context.h" #include "third_party/blink/renderer/core/frame/local_frame.h" @@ -32,6 +38,30 @@ return c == '\0' || c == '\r' || c == '\n'; } +bool ValidPartitionedCookiesOriginTrial(const ResourceResponse& response) { + // This should never be called if partitioned cookies are disabled. + DCHECK(base::FeatureList::IsEnabled(net::features::kPartitionedCookies)); + + if (!response.HttpHeaderFields().Contains("origin-trial")) + return false; + + blink::TrialTokenValidator validator; + base::Time now(base::Time::Now()); + + GURL url(response.ResponseUrl()); + if (!validator.IsTrialPossibleOnOrigin(url)) + return false; + + url::Origin origin = url::Origin::Create(url); + url::Origin third_party_origins[] = {origin}; + StringUTF8Adaptor token_adaptor(response.HttpHeaderField("origin-trial")); + TrialTokenResult result = validator.ValidateToken( + token_adaptor.AsStringPiece(), origin, third_party_origins, now); + + return result.Status() == blink::OriginTrialTokenStatus::kSuccess && + result.ParsedToken()->feature_name() == "PartitionedCookies"; +} + } // namespace CookieJar::CookieJar(blink::Document* document) @@ -115,4 +145,19 @@ return false; } +void CookieJar::CheckPartitionedCookiesOriginTrial( + const ResourceResponse& response) { + if (!response.HasPartitionedCookie() || + 
!base::FeatureList::IsEnabled(net::features::kPartitionedCookies)) { + return; + } + if (!ValidPartitionedCookiesOriginTrial(response)) { + base::ElapsedTimer timer; + bool requested = RequestRestrictedCookieManagerIfNeeded(); + LogCookieHistogram("Blink.CookiesEnabledTime.", requested, + timer.Elapsed()); + backend_->ConvertPartitionedCookiesToUnpartitioned(response.ResponseUrl()); + } +} + } // namespace blink
diff --git a/third_party/blink/renderer/core/loader/cookie_jar.h b/third_party/blink/renderer/core/loader/cookie_jar.h index e91c441..ecaa64d 100644 --- a/third_party/blink/renderer/core/loader/cookie_jar.h +++ b/third_party/blink/renderer/core/loader/cookie_jar.h
@@ -8,6 +8,7 @@ #include "services/network/public/mojom/restricted_cookie_manager.mojom-blink.h" #include "third_party/blink/renderer/platform/heap/garbage_collected.h" +#include "third_party/blink/renderer/platform/loader/fetch/resource_response.h" #include "third_party/blink/renderer/platform/mojo/heap_mojo_remote.h" #include "third_party/blink/renderer/platform/wtf/text/wtf_string.h" @@ -27,6 +28,16 @@ mojo::PendingRemote<network::mojom::blink::RestrictedCookieManager> cookie_manager); + // This function checks subresource requests for the partitioned cookies + // origin trial. We only consider requests that: + // - have a Set-Cookie header + // - have Partitioned in the cookie line + // If both of these conditions are met, we check if the response contains an + // Origin-Trial header with a valid token. If it does not, we revert that + // URL's partitioned cookies to unpartitioned. + // TODO(https://crbug.com/1296161): Delete this function. + void CheckPartitionedCookiesOriginTrial(const ResourceResponse& response); + private: bool RequestRestrictedCookieManagerIfNeeded();
diff --git a/third_party/blink/renderer/core/loader/frame_client_hints_preferences_context.cc b/third_party/blink/renderer/core/loader/frame_client_hints_preferences_context.cc index 5755fb6..dfa347d 100644 --- a/third_party/blink/renderer/core/loader/frame_client_hints_preferences_context.cc +++ b/third_party/blink/renderer/core/loader/frame_client_hints_preferences_context.cc
@@ -72,8 +72,6 @@ WebFeature::kClientHintsUAFull}, {network::mojom::WebClientHintsType::kUAWoW64, WebFeature::kClientHintsUAWoW64}, - {network::mojom::WebClientHintsType::kPartitionedCookies, - WebFeature::kClientHintsPartitionedCookies}, {network::mojom::WebClientHintsType::kSaveData, WebFeature::kClientHintsSaveData}, };
diff --git a/third_party/blink/renderer/core/loader/resource_load_observer_for_frame.cc b/third_party/blink/renderer/core/loader/resource_load_observer_for_frame.cc index 02a7f71e..c7d6ed6f 100644 --- a/third_party/blink/renderer/core/loader/resource_load_observer_for_frame.cc +++ b/third_party/blink/renderer/core/loader/resource_load_observer_for_frame.cc
@@ -315,6 +315,8 @@ // It is essential that inspector gets resource response BEFORE console. frame->Console().ReportResourceResponseReceived(document_loader_, identifier, response); + + document_->CheckPartitionedCookiesOriginTrial(response); } void ResourceLoadObserverForFrame::DidReceiveData(
diff --git a/third_party/blink/renderer/core/permissions_policy/permissions_policy_features.json5 b/third_party/blink/renderer/core/permissions_policy/permissions_policy_features.json5 index 26f3a5c..6ef895fe 100644 --- a/third_party/blink/renderer/core/permissions_policy/permissions_policy_features.json5 +++ b/third_party/blink/renderer/core/permissions_policy/permissions_policy_features.json5
@@ -109,11 +109,6 @@ permissions_policy_name: "ch-ect", }, { - name: "ClientHintPartitionedCookies", - permissions_policy_name: "ch-partitioned-cookies", - depends_on: ["PartitionedCookies"], - }, - { name: "ClientHintPrefersColorScheme", permissions_policy_name: "ch-prefers-color-scheme", },
diff --git a/third_party/blink/renderer/core/script/script_loader.cc b/third_party/blink/renderer/core/script/script_loader.cc index 53529e3f..2629a08 100644 --- a/third_party/blink/renderer/core/script/script_loader.cc +++ b/third_party/blink/renderer/core/script/script_loader.cc
@@ -174,18 +174,8 @@ namespace { // <specdef href="https://html.spec.whatwg.org/C/#prepare-a-script"> -bool IsValidClassicScriptTypeAndLanguage( - const String& type, - const String& language, - ScriptLoader::LegacyTypeSupport support_legacy_types) { - // FIXME: IsLegacySupportedJavaScriptLanguage() is not valid HTML5. It is used - // here to maintain backwards compatibility with existing web tests. The - // specific violations are: - // - Allowing type=javascript. type= should only support MIME types, such as - // text/javascript. - // - Allowing a different set of languages for language= and type=. language= - // supports Javascript 1.1 and 1.4-1.6, but type= does not. - +bool IsValidClassicScriptTypeAndLanguage(const String& type, + const String& language) { if (type.IsNull()) { // <spec step="8">the script element has no type attribute but it has a // language attribute and that attribute's value is the empty string, @@ -215,12 +205,6 @@ type.StripWhiteSpace())) { return true; } - - // Not spec'ed. - if (support_legacy_types == ScriptLoader::kAllowLegacyTypeInTypeAttribute && - MIMETypeRegistry::IsLegacySupportedJavaScriptLanguage(type)) { - return true; - } } return false; @@ -235,10 +219,8 @@ ScriptLoader::ScriptTypeAtPrepare ScriptLoader::GetScriptTypeAtPrepare( const String& type, - const String& language, - LegacyTypeSupport support_legacy_types) { - if (IsValidClassicScriptTypeAndLanguage(type, language, - support_legacy_types)) { + const String& language) { + if (IsValidClassicScriptTypeAndLanguage(type, language)) { // <spec step="8">... If the script block's type string is a JavaScript MIME // type essence match, the script's type is "classic". 
...</spec> return ScriptTypeAtPrepare::kClassic; @@ -317,8 +299,7 @@ } // <specdef href="https://html.spec.whatwg.org/C/#prepare-a-script"> -bool ScriptLoader::PrepareScript(const TextPosition& script_start_position, - LegacyTypeSupport support_legacy_types) { +bool ScriptLoader::PrepareScript(const TextPosition& script_start_position) { // <spec step="1">If the script element is marked as having "already started", // then return. The script is not executed.</spec> if (already_started_) @@ -367,8 +348,7 @@ // <spec step="7">... Determine the script's type as follows: ...</spec> script_type_ = GetScriptTypeAtPrepare(element_->TypeAttributeValue(), - element_->LanguageAttributeValue(), - support_legacy_types); + element_->LanguageAttributeValue()); switch (GetScriptType()) { case ScriptTypeAtPrepare::kInvalid: @@ -1120,6 +1100,32 @@ scheduling_type); } + // Record usage histograms per page. + switch (scheduling_type) { + case ScriptSchedulingType::kDefer: + UseCounter::Count(element_->GetDocument(), + WebFeature::kScriptSchedulingType_Defer); + break; + case ScriptSchedulingType::kParserBlocking: + UseCounter::Count(element_->GetDocument(), + WebFeature::kScriptSchedulingType_ParserBlocking); + break; + case ScriptSchedulingType::kParserBlockingInline: + UseCounter::Count(element_->GetDocument(), + WebFeature::kScriptSchedulingType_ParserBlockingInline); + break; + case ScriptSchedulingType::kInOrder: + UseCounter::Count(element_->GetDocument(), + WebFeature::kScriptSchedulingType_InOrder); + break; + case ScriptSchedulingType::kAsync: + UseCounter::Count(element_->GetDocument(), + WebFeature::kScriptSchedulingType_Async); + break; + default: + break; + } + PendingScript* pending_script = prepared_pending_script_; prepared_pending_script_ = nullptr; pending_script->SetSchedulingType(scheduling_type);
diff --git a/third_party/blink/renderer/core/script/script_loader.h b/third_party/blink/renderer/core/script/script_loader.h index 368d588..ee98491 100644 --- a/third_party/blink/renderer/core/script/script_loader.h +++ b/third_party/blink/renderer/core/script/script_loader.h
@@ -53,11 +53,6 @@ const char* NameInHeapSnapshot() const override { return "ScriptLoader"; } String DebugName() const override { return "ScriptLoader"; } - enum LegacyTypeSupport { - kDisallowLegacyTypeInTypeAttribute, - kAllowLegacyTypeInTypeAttribute - }; - // Script type at the time of #prepare-a-script. Import maps are included here // but not in `mojom::blink::ScriptType` because import maps are handled // differently from ordinal scripts after PrepareScript(). @@ -72,8 +67,7 @@ static ScriptTypeAtPrepare GetScriptTypeAtPrepare( const String& type_attribute_value, - const String& language_attribute_value, - LegacyTypeSupport support_legacy_types); + const String& language_attribute_value); static bool BlockForNoModule(ScriptTypeAtPrepare, bool nomodule); @@ -82,8 +76,7 @@ // https://html.spec.whatwg.org/C/#prepare-a-script bool PrepareScript(const TextPosition& script_start_position = - TextPosition::MinimumPosition(), - LegacyTypeSupport = kDisallowLegacyTypeInTypeAttribute); + TextPosition::MinimumPosition()); // Gets a PendingScript for external script whose fetch is started in // FetchClassicScript()/FetchModuleScriptTree().
diff --git a/third_party/blink/renderer/platform/exported/web_url_response.cc b/third_party/blink/renderer/platform/exported/web_url_response.cc index 67fb80e4..d390aad 100644 --- a/third_party/blink/renderer/platform/exported/web_url_response.cc +++ b/third_party/blink/renderer/platform/exported/web_url_response.cc
@@ -490,4 +490,8 @@ WebURLResponse::WebURLResponse(ResourceResponse& r) : resource_response_(&r) {} +void WebURLResponse::SetHasPartitionedCookie(bool has_partitioned_cookie) { + resource_response_->SetHasPartitionedCookie(has_partitioned_cookie); +} + } // namespace blink
diff --git a/third_party/blink/renderer/platform/loader/allowed_by_nosniff.cc b/third_party/blink/renderer/platform/loader/allowed_by_nosniff.cc index 5ed1f8c..5cea934 100644 --- a/third_party/blink/renderer/platform/loader/allowed_by_nosniff.cc +++ b/third_party/blink/renderer/platform/loader/allowed_by_nosniff.cc
@@ -115,9 +115,10 @@ // we still wish to accept them (or log them using UseCounter, or add a // deprecation warning to the console). - if (mime_type.StartsWithIgnoringASCIICase("text/") && - MIMETypeRegistry::IsLegacySupportedJavaScriptLanguage( - mime_type.Substring(5))) { + if (EqualIgnoringASCIICase(mime_type, "text/javascript1.6") || + EqualIgnoringASCIICase(mime_type, "text/javascript1.7")) { + // We've been excluding these legacy values from UseCounter stats since + // before. return true; }
diff --git a/third_party/blink/renderer/platform/loader/fetch/resource_loader.cc b/third_party/blink/renderer/platform/loader/fetch/resource_loader.cc index 7793a02..ff63b3d0 100644 --- a/third_party/blink/renderer/platform/loader/fetch/resource_loader.cc +++ b/third_party/blink/renderer/platform/loader/fetch/resource_loader.cc
@@ -37,6 +37,7 @@ #include "base/metrics/histogram_functions.h" #include "base/metrics/histogram_macros.h" #include "mojo/public/cpp/bindings/pending_remote.h" +#include "services/metrics/public/cpp/metrics_utils.h" #include "services/metrics/public/cpp/mojo_ukm_recorder.h" #include "services/metrics/public/cpp/ukm_builders.h" #include "services/network/public/cpp/cross_origin_embedder_policy.h" @@ -1233,15 +1234,28 @@ 0, false); } -void ResourceLoader::DidFinishLoading(base::TimeTicks response_end_time, - int64_t encoded_data_length, - int64_t encoded_body_length, - int64_t decoded_body_length, - bool should_report_corb_blocking) { +void ResourceLoader::DidFinishLoading( + base::TimeTicks response_end_time, + int64_t encoded_data_length, + int64_t encoded_body_length, + int64_t decoded_body_length, + bool should_report_corb_blocking, + absl::optional<bool> pervasive_payload_requested) { resource_->SetEncodedDataLength(encoded_data_length); resource_->SetEncodedBodyLength(encoded_body_length); resource_->SetDecodedBodyLength(decoded_body_length); + if (pervasive_payload_requested.has_value()) { + auto* ukm_recorder = ukm::UkmRecorder::Get(); + ukm::SourceId ukm_source_id = + resource_->GetResourceRequest().GetUkmSourceId(); + ukm::builders::Network_CacheTransparency builder(ukm_source_id); + builder.SetFoundPervasivePayload(pervasive_payload_requested.value()); + builder.SetTotalBytesFetched( + ukm::GetExponentialBucketMinForBytes(encoded_data_length)); + builder.Record(ukm_recorder->Get()); + } + response_end_time_for_error_cases_ = response_end_time; if ((response_body_loader_ && !has_seen_end_of_body_ &&
diff --git a/third_party/blink/renderer/platform/loader/fetch/resource_loader.h b/third_party/blink/renderer/platform/loader/fetch/resource_loader.h index aa4f85e..0a76e066 100644 --- a/third_party/blink/renderer/platform/loader/fetch/resource_loader.h +++ b/third_party/blink/renderer/platform/loader/fetch/resource_loader.h
@@ -149,7 +149,9 @@ int64_t encoded_data_length, int64_t encoded_body_length, int64_t decoded_body_length, - bool should_report_corb_blocking) override; + bool should_report_corb_blocking, + absl::optional<bool> pervasive_payload_requested = + absl::nullopt) override; void DidFail(const WebURLError&, base::TimeTicks response_end_time, int64_t encoded_data_length,
diff --git a/third_party/blink/renderer/platform/loader/fetch/resource_response.h b/third_party/blink/renderer/platform/loader/fetch/resource_response.h index ceccaa9ac..9c6c8308 100644 --- a/third_party/blink/renderer/platform/loader/fetch/resource_response.h +++ b/third_party/blink/renderer/platform/loader/fetch/resource_response.h
@@ -431,6 +431,11 @@ request_include_credentials_ = request_include_credentials; } + bool HasPartitionedCookie() const { return has_partitioned_cookie_; } + void SetHasPartitionedCookie(bool has_partitioned_cookie) { + has_partitioned_cookie_ = has_partitioned_cookie; + } + private: void UpdateHeaderParsedState(const AtomicString& name); @@ -625,6 +630,10 @@ absl::optional<net::AuthChallengeInfo> auth_challenge_info_; bool emitted_extra_info_ = false; + + // See URLResponseHead.has_partitioned_cookie. + // TODO(https://crbug.com/1296161): Delete this field. + bool has_partitioned_cookie_ = false; }; } // namespace blink
diff --git a/third_party/blink/renderer/platform/loader/fetch/url_loader/web_url_loader.cc b/third_party/blink/renderer/platform/loader/fetch/url_loader/web_url_loader.cc index 2af5816..daf2621 100644 --- a/third_party/blink/renderer/platform/loader/fetch/url_loader/web_url_loader.cc +++ b/third_party/blink/renderer/platform/loader/fetch/url_loader/web_url_loader.cc
@@ -29,6 +29,7 @@ #include "base/time/time.h" #include "build/build_config.h" #include "mojo/public/cpp/bindings/pending_remote.h" +#include "net/base/features.h" #include "net/base/filename_util.h" #include "net/base/host_port_pair.h" #include "net/base/ip_endpoint.h" @@ -39,6 +40,7 @@ #include "net/cert/ct_sct_to_string.h" #include "net/cert/x509_certificate.h" #include "net/cert/x509_util.h" +#include "net/cookies/parsed_cookie.h" #include "net/http/http_request_headers.h" #include "net/http/http_response_headers.h" #include "net/ssl/ssl_cipher_suite_names.h" @@ -620,7 +622,8 @@ } else { client_->DidFinishLoading(status.completion_time, total_transfer_size, encoded_body_size, status.decoded_body_length, - status.should_report_corb_blocking); + status.should_report_corb_blocking, + status.pervasive_payload_requested); } } } @@ -789,6 +792,8 @@ response->AddHttpHeaderField(WebString::FromLatin1(name), WebString::FromLatin1(value)); } + + response->SetHasPartitionedCookie(head.has_partitioned_cookie); } // static
diff --git a/third_party/blink/renderer/platform/loader/fetch/url_loader/web_url_loader_unittest.cc b/third_party/blink/renderer/platform/loader/fetch/url_loader/web_url_loader_unittest.cc index cd799e5..2d4f93a7 100644 --- a/third_party/blink/renderer/platform/loader/fetch/url_loader/web_url_loader_unittest.cc +++ b/third_party/blink/renderer/platform/loader/fetch/url_loader/web_url_loader_unittest.cc
@@ -244,11 +244,13 @@ NOTREACHED(); } - void DidFinishLoading(base::TimeTicks finishTime, - int64_t totalEncodedDataLength, - int64_t totalEncodedBodyLength, - int64_t totalDecodedBodyLength, - bool should_report_corb_blocking) override { + void DidFinishLoading( + base::TimeTicks finishTime, + int64_t totalEncodedDataLength, + int64_t totalEncodedBodyLength, + int64_t totalDecodedBodyLength, + bool should_report_corb_blocking, + absl::optional<bool> pervasive_payload_requested) override { EXPECT_TRUE(loader_); EXPECT_TRUE(did_receive_response_); EXPECT_FALSE(did_finish_);
diff --git a/third_party/blink/renderer/platform/network/mime/mime_type_registry.cc b/third_party/blink/renderer/platform/network/mime/mime_type_registry.cc index 201e939..f1478a9b 100644 --- a/third_party/blink/renderer/platform/network/mime/mime_type_registry.cc +++ b/third_party/blink/renderer/platform/network/mime/mime_type_registry.cc
@@ -118,32 +118,6 @@ return blink::IsJSONMimeType(ToLowerASCIIOrEmpty(mime_type)); } -bool MIMETypeRegistry::IsLegacySupportedJavaScriptLanguage( - const String& language) { - // Mozilla 1.8 accepts javascript1.0 - javascript1.7, but WinIE 7 accepts only - // javascript1.1 - javascript1.3. - // Mozilla 1.8 and WinIE 7 both accept javascript and livescript. - // WinIE 7 accepts ecmascript and jscript, but Mozilla 1.8 doesn't. - // Neither Mozilla 1.8 nor WinIE 7 accept leading or trailing whitespace. - // We want to accept all the values that either of these browsers accept, but - // not other values. - - // FIXME: This function is not HTML5 compliant. These belong in the MIME - // registry as "text/javascript<version>" entries. - return EqualIgnoringASCIICase(language, "javascript") || - EqualIgnoringASCIICase(language, "javascript1.0") || - EqualIgnoringASCIICase(language, "javascript1.1") || - EqualIgnoringASCIICase(language, "javascript1.2") || - EqualIgnoringASCIICase(language, "javascript1.3") || - EqualIgnoringASCIICase(language, "javascript1.4") || - EqualIgnoringASCIICase(language, "javascript1.5") || - EqualIgnoringASCIICase(language, "javascript1.6") || - EqualIgnoringASCIICase(language, "javascript1.7") || - EqualIgnoringASCIICase(language, "livescript") || - EqualIgnoringASCIICase(language, "ecmascript") || - EqualIgnoringASCIICase(language, "jscript"); -} - bool MIMETypeRegistry::IsSupportedNonImageMIMEType(const String& mime_type) { return blink::IsSupportedNonImageMimeType(ToLowerASCIIOrEmpty(mime_type)); }
diff --git a/third_party/blink/renderer/platform/network/mime/mime_type_registry.h b/third_party/blink/renderer/platform/network/mime/mime_type_registry.h index 17a8bdb..a0f1329c8 100644 --- a/third_party/blink/renderer/platform/network/mime/mime_type_registry.h +++ b/third_party/blink/renderer/platform/network/mime/mime_type_registry.h
@@ -70,8 +70,6 @@ // https://mimesniff.spec.whatwg.org/#json-mime-type static bool IsJSONMimeType(const String& mime_type); - static bool IsLegacySupportedJavaScriptLanguage(const String& language); - // Checks to see if a non-image mime type is suitable for being loaded as a // document in a frame. Includes supported JavaScript MIME types. static bool IsSupportedNonImageMIMEType(const String& mime_type);
diff --git a/third_party/blink/web_tests/platform/generic/webexposed/feature-policy-features-expected.txt b/third_party/blink/web_tests/platform/generic/webexposed/feature-policy-features-expected.txt index 4c2b438..8c63591 100644 --- a/third_party/blink/web_tests/platform/generic/webexposed/feature-policy-features-expected.txt +++ b/third_party/blink/web_tests/platform/generic/webexposed/feature-policy-features-expected.txt
@@ -10,7 +10,6 @@ ch-downlink ch-dpr ch-ect -ch-partitioned-cookies ch-prefers-color-scheme ch-rtt ch-save-data
diff --git a/third_party/ipcz/include/ipcz/ipcz.h b/third_party/ipcz/include/ipcz/ipcz.h index 1a98dfca..d363d388 100644 --- a/third_party/ipcz/include/ipcz/ipcz.h +++ b/third_party/ipcz/include/ipcz/ipcz.h
@@ -231,13 +231,13 @@ // // `options` is currently unused and must be null. typedef IpczResult(IPCZ_API* IpczTransportActivityHandler)( - IpczHandle transport, - const void* data, - size_t num_bytes, - const IpczDriverHandle* driver_handles, - size_t num_driver_handles, - IpczTransportActivityFlags flags, - const void* options); + IpczHandle transport, // in + const void* data, // in + size_t num_bytes, // in + const IpczDriverHandle* driver_handles, // in + size_t num_driver_handles, // in + IpczTransportActivityFlags flags, // in + const void* options); // in // Structure to be filled in by a driver's GetSharedMemoryInfo(). struct IPCZ_ALIGN(8) IpczSharedMemoryInfo { @@ -266,9 +266,9 @@ // Called by ipcz to request that the driver release the object identified by // `handle`. - IpczResult(IPCZ_API* Close)(IpczDriverHandle handle, - uint32_t flags, - const void* options); + IpczResult(IPCZ_API* Close)(IpczDriverHandle handle, // in + uint32_t flags, // in + const void* options); // in // Serializes a driver object identified by `handle` into a collection of // bytes and readily transmissible driver objects, for eventual transmission @@ -317,14 +317,14 @@ // return IPCZ_RESULT_OK. In this case ipcz relinquishes `handle` and will no // longer refer to it unless the driver outputs it back in `handles`, implying // that it was already transmissible as-is. 
- IpczResult(IPCZ_API* Serialize)(IpczDriverHandle handle, - IpczDriverHandle transport, - uint32_t flags, - const void* options, - void* data, - size_t* num_bytes, - IpczDriverHandle* handles, - size_t* num_handles); + IpczResult(IPCZ_API* Serialize)(IpczDriverHandle handle, // in + IpczDriverHandle transport, // in + uint32_t flags, // in + const void* options, // in + void* data, // out + size_t* num_bytes, // in/out + IpczDriverHandle* handles, // out + size_t* num_handles); // in/out // Deserializes a driver object from a collection of bytes and transmissible // driver objects which which was originally produced by Serialize() and @@ -333,14 +333,15 @@ // Any return value other than IPCZ_RESULT_OK indicates an error and implies // that `handle` is unmodified. Otherwise `handle` must contain a valid driver // handle to the deserialized object. - IpczResult(IPCZ_API* Deserialize)(const void* data, - size_t num_bytes, - const IpczDriverHandle* driver_handles, - size_t num_driver_handles, - IpczDriverHandle transport, - uint32_t flags, - const void* options, - IpczDriverHandle* handle); + IpczResult(IPCZ_API* Deserialize)( + const void* data, // in + size_t num_bytes, // in + const IpczDriverHandle* driver_handles, // in + size_t num_driver_handles, // in + IpczDriverHandle transport, // in + uint32_t flags, // in + const void* options, // in + IpczDriverHandle* handle); // out // Creates a new pair of entangled bidirectional transports, returning them in // `new_transport0` and `new_transport1`. @@ -354,12 +355,13 @@ // // Any return value other than IPCZ_RESULT_OK indicates an error and implies // that `new_transport0` and `new_transport1` are unmodified. 
- IpczResult(IPCZ_API* CreateTransports)(IpczDriverHandle transport0, - IpczDriverHandle transport1, - uint32_t flags, - const void* options, - IpczDriverHandle* new_transport0, - IpczDriverHandle* new_transport1); + IpczResult(IPCZ_API* CreateTransports)( + IpczDriverHandle transport0, // in + IpczDriverHandle transport1, // in + uint32_t flags, // in + const void* options, // in + IpczDriverHandle* new_transport0, // out + IpczDriverHandle* new_transport1); // out // Called by ipcz to activate a transport. `driver_transport` is the // driver-side handle assigned to the transport by the driver, either as given @@ -387,19 +389,20 @@ // The driver may elicit forced destruction of itself by calling // `activity_handler` with the flag IPCZ_TRANSPORT_ACTIVITY_DEACTIVATED. IpczResult(IPCZ_API* ActivateTransport)( - IpczDriverHandle driver_transport, - IpczHandle transport, - IpczTransportActivityHandler activity_handler, - uint32_t flags, - const void* options); + IpczDriverHandle driver_transport, // in + IpczHandle transport, // in + IpczTransportActivityHandler activity_handler, // in + uint32_t flags, // in + const void* options); // in // Called by ipcz to deactivate a transport. Once this returns successfully, // the driver must make no further calls into this transport's activity // handler. ipcz may continue to use the transport for outgoing transmissions // until the driver's Close() is also called on `driver_transport`. - IpczResult(IPCZ_API* DeactivateTransport)(IpczDriverHandle driver_transport, - uint32_t flags, - const void* options); + IpczResult(IPCZ_API* DeactivateTransport)( + IpczDriverHandle driver_transport, // in + uint32_t flags, // in + const void* options); // in // Called by ipcz to delegate transmission of data and driver handles over the // identified transport endpoint. 
If the driver cannot fulfill the request, @@ -420,36 +423,38 @@ // // If ipcz only wants to wake the peer node rather than transmit data or // handles, `num_bytes` and `num_driver_handles` may both be zero. - IpczResult(IPCZ_API* Transmit)(IpczDriverHandle driver_transport, - const void* data, - size_t num_bytes, - const IpczDriverHandle* driver_handles, - size_t num_driver_handles, - uint32_t flags, - const void* options); + IpczResult(IPCZ_API* Transmit)(IpczDriverHandle driver_transport, // in + const void* data, // in + size_t num_bytes, // in + const IpczDriverHandle* driver_handles, // in + size_t num_driver_handles, // in + uint32_t flags, // in + const void* options); // in // Allocates a shared memory region and returns a driver handle in // `driver_memory` which can be used to reference it in other calls to the // driver. - IpczResult(IPCZ_API* AllocateSharedMemory)(size_t num_bytes, - uint32_t flags, - const void* options, - IpczDriverHandle* driver_memory); + IpczResult(IPCZ_API* AllocateSharedMemory)( + size_t num_bytes, // in + uint32_t flags, // in + const void* options, // in + IpczDriverHandle* driver_memory); // out // Returns information about the shared memory region identified by // `driver_memory`. - IpczResult(IPCZ_API* GetSharedMemoryInfo)(IpczDriverHandle driver_memory, - uint32_t flags, - const void* options, - struct IpczSharedMemoryInfo* info); + IpczResult(IPCZ_API* GetSharedMemoryInfo)( + IpczDriverHandle driver_memory, // in + uint32_t flags, // in + const void* options, // in + struct IpczSharedMemoryInfo* info); // out // Duplicates a shared memory region handle into a new distinct handle // referencing the same underlying region. 
IpczResult(IPCZ_API* DuplicateSharedMemory)( - IpczDriverHandle driver_memory, - uint32_t flags, - const void* options, - IpczDriverHandle* new_driver_memory); + IpczDriverHandle driver_memory, // in + uint32_t flags, // in + const void* options, // in + IpczDriverHandle* new_driver_memory); // out // Maps a shared memory region identified by `driver_memory` and returns its // mapped address in `address` on success and a driver handle in @@ -460,17 +465,18 @@ // of `driver_memory`. That is, if `driver_memory` is closed immediately after // this call succeeds, the returned mapping must still remain valid until the // mapping itself is closed. - IpczResult(IPCZ_API* MapSharedMemory)(IpczDriverHandle driver_memory, - uint32_t flags, - const void* options, - void** address, - IpczDriverHandle* driver_mapping); + IpczResult(IPCZ_API* MapSharedMemory)( + IpczDriverHandle driver_memory, // in + uint32_t flags, // in + const void* options, // in + void** address, // out + IpczDriverHandle* driver_mapping); // out // Generates `num_bytes` bytes of random data to fill `buffer`. - IpczResult(IPCZ_API* GenerateRandomBytes)(size_t num_bytes, - uint32_t flags, - const void* options, - void* buffer); + IpczResult(IPCZ_API* GenerateRandomBytes)(size_t num_bytes, // in + uint32_t flags, // in + const void* options, // in + void* buffer); // out }; #if defined(__cplusplus) @@ -824,9 +830,9 @@ // successfully closed by this operation. // // IPCZ_RESULT_INVALID_ARGUMENT if `handle` is invalid. - IpczResult(IPCZ_API* Close)(IpczHandle handle, - uint32_t flags, - const void* options); + IpczResult(IPCZ_API* Close)(IpczHandle handle, // in + uint32_t flags, // in + const void* options); // in // Initializes a new ipcz node. Applications typically need only one node in // each communicating process, but it's OK to create more. Practical use cases @@ -864,11 +870,11 @@ // from operating correctly. 
For example, the is returned if ipcz was // built against a std::atomic implementation which does not provide // lock-free 32-bit and 64-bit atomics. - IpczResult(IPCZ_API* CreateNode)(const struct IpczDriver* driver, - IpczDriverHandle driver_node, - IpczCreateNodeFlags flags, - const void* options, - IpczHandle* node); + IpczResult(IPCZ_API* CreateNode)(const struct IpczDriver* driver, // in + IpczDriverHandle driver_node, // in + IpczCreateNodeFlags flags, // in + const void* options, // in + IpczHandle* node); // out // Connects `node` to another node in the system using an application-provided // driver transport handle in `driver_transport` for communication. If this @@ -935,12 +941,12 @@ // IPCZ_RESULT_OUT_OF_RANGE if `num_initial_portals` is larger than the // ipcz implementation allows. There is no hard limit specified, but // any ipcz implementation must support at least 8 initial portals. - IpczResult(IPCZ_API* ConnectNode)(IpczHandle node, - IpczDriverHandle driver_transport, - size_t num_initial_portals, - IpczConnectNodeFlags flags, - const void* options, - IpczHandle* initial_portals); + IpczResult(IPCZ_API* ConnectNode)(IpczHandle node, // in + IpczDriverHandle driver_transport, // in + size_t num_initial_portals, // in + IpczConnectNodeFlags flags, // in + const void* options, // in + IpczHandle* initial_portals); // out // Opens two new portals which exist as each other's opposite. // @@ -965,11 +971,11 @@ // // IPCZ_RESULT_INVALID_ARGUMENT if `node` is invalid, or if either // `portal0` or `portal1` is null. - IpczResult(IPCZ_API* OpenPortals)(IpczHandle node, - uint32_t flags, - const void* options, - IpczHandle* portal0, - IpczHandle* portal1); + IpczResult(IPCZ_API* OpenPortals)(IpczHandle node, // in + uint32_t flags, // in + const void* options, // in + IpczHandle* portal0, // out + IpczHandle* portal1); // out // Merges two portals into each other, effectively destroying both while // linking their respective peer portals with each other. 
A portal cannot @@ -1006,10 +1012,10 @@ // // IPCZ_RESULT_FAILED_PRECONDITION if either `first` or `second` has // already had one or more parcels put into or gotten out of them. - IpczResult(IPCZ_API* MergePortals)(IpczHandle first, - IpczHandle second, - uint32_t flags, - const void* options); + IpczResult(IPCZ_API* MergePortals)(IpczHandle first, // in + IpczHandle second, // in + uint32_t flags, // in + const void* options); // in // Queries specific details regarding the status of a portal, such as the // number of unread parcels or data bytes available on the portal or its @@ -1031,10 +1037,11 @@ // // IPCZ_RESULT_INVALID_ARGUMENT `portal` is invalid. `status` is null or // invalid. - IpczResult(IPCZ_API* QueryPortalStatus)(IpczHandle portal, - uint32_t flags, - const void* options, - struct IpczPortalStatus* status); + IpczResult(IPCZ_API* QueryPortalStatus)( + IpczHandle portal, // in + uint32_t flags, // in + const void* options, // in + struct IpczPortalStatus* status); // out // Puts any combination of data and handles into the portal identified by // `portal`. Everything put into a portal can be retrieved in the same order @@ -1077,13 +1084,13 @@ // // IPCZ_RESULT_NOT_FOUND if it is known that the opposite portal has // already been closed and anything put into this portal would be lost. - IpczResult(IPCZ_API* Put)(IpczHandle portal, - const void* data, - size_t num_bytes, - const IpczHandle* handles, - size_t num_handles, - uint32_t flags, - const struct IpczPutOptions* options); + IpczResult(IPCZ_API* Put)(IpczHandle portal, // in + const void* data, // in + size_t num_bytes, // in + const IpczHandle* handles, // in + size_t num_handles, // in + uint32_t flags, // in + const struct IpczPutOptions* options); // in // Begins a two-phase put operation on `portal`.
While a two-phase put // operation is in progress on a portal, any other BeginPut() call on the same @@ -1132,11 +1139,12 @@ // // IPCZ_RESULT_NOT_FOUND if it is known that the opposite portal has // already been closed and anything put into this portal would be lost. - IpczResult(IPCZ_API* BeginPut)(IpczHandle portal, - IpczBeginPutFlags flags, - const struct IpczBeginPutOptions* options, - size_t* num_bytes, - void** data); + IpczResult(IPCZ_API* BeginPut)( + IpczHandle portal, // in + IpczBeginPutFlags flags, // in + const struct IpczBeginPutOptions* options, // in + size_t* num_bytes, // out + void** data); // out // Ends the two-phase put operation started by the most recent successful call // to BeginPut() on `portal`. @@ -1178,12 +1186,12 @@ // // IPCZ_RESULT_NOT_FOUND if it is known that the opposite portal has // already been closed and anything put into this portal would be lost. - IpczResult(IPCZ_API* EndPut)(IpczHandle portal, - size_t num_bytes_produced, - const IpczHandle* handles, - size_t num_handles, - IpczEndPutFlags flags, - const void* options); + IpczResult(IPCZ_API* EndPut)(IpczHandle portal, // in + size_t num_bytes_produced, // in + const IpczHandle* handles, // in + size_t num_handles, // in + IpczEndPutFlags flags, // in + const void* options); // in // Retrieves some combination of data and handles from a portal, as placed by // a prior put operation on the opposite portal. @@ -1244,13 +1252,13 @@ // // IPCZ_RESULT_ALREADY_EXISTS if there is a two-phase get operation in // progress on `portal`. 
- IpczResult(IPCZ_API* Get)(IpczHandle portal, - IpczGetFlags flags, - const void* options, - void* data, - size_t* num_bytes, - IpczHandle* handles, - size_t* num_handles); + IpczResult(IPCZ_API* Get)(IpczHandle portal, // in + IpczGetFlags flags, // in + const void* options, // in + void* data, // out + size_t* num_bytes, // in/out + IpczHandle* handles, // out + size_t* num_handles); // in/out // Begins a two-phase get operation on `portal` to retrieve data and handles. // While a two-phase get operation is in progress on a portal, all other get @@ -1304,12 +1312,12 @@ // // IPCZ_RESULT_ALREADY_EXISTS if there is already a two-phase get operation // in progress on `portal`. - IpczResult(IPCZ_API* BeginGet)(IpczHandle portal, - uint32_t flags, - const void* options, - const void** data, - size_t* num_bytes, - size_t* num_handles); + IpczResult(IPCZ_API* BeginGet)(IpczHandle portal, // in + uint32_t flags, // in + const void* options, // in + const void** data, // out + size_t* num_bytes, // out + size_t* num_handles); // out // Ends the two-phase get operation started by the most recent successful call // to BeginGet() on `portal`. @@ -1341,12 +1349,12 @@ // // IPCZ_RESULT_FAILED_PRECONDITION if there was no two-phase get operation // in progress on `portal`. - IpczResult(IPCZ_API* EndGet)(IpczHandle portal, - size_t num_bytes_consumed, - size_t num_handles, - IpczEndGetFlags flags, - const void* options, - IpczHandle* handles); + IpczResult(IPCZ_API* EndGet)(IpczHandle portal, // in + size_t num_bytes_consumed, // in + size_t num_handles, // in + IpczEndGetFlags flags, // in + const void* options, // in + IpczHandle* handles); // out // Attempts to install a trap to catch interesting changes to a portal's // state. The condition(s) to observe are specified in `conditions`. @@ -1394,14 +1402,15 @@ // `conditions` which were already satisfied by the portal's state. If // `status` is non-null, a copy of the portal's last known status will // also be stored there. 
- IpczResult(IPCZ_API* Trap)(IpczHandle portal, - const struct IpczTrapConditions* conditions, - IpczTrapEventHandler handler, - uintptr_t context, - uint32_t flags, - const void* options, - IpczTrapConditionFlags* satisfied_condition_flags, - struct IpczPortalStatus* status); + IpczResult(IPCZ_API* Trap)( + IpczHandle portal, // in + const struct IpczTrapConditions* conditions, // in + IpczTrapEventHandler handler, // in + uintptr_t context, // in + uint32_t flags, // in + const void* options, // in + IpczTrapConditionFlags* satisfied_condition_flags, // out + struct IpczPortalStatus* status); // out // Boxes an object managed by a node's driver and returns a new IpczHandle to // reference the box. If the driver is able to serialize the boxed object, the @@ -1420,11 +1429,11 @@ // returned in `handle`. // // IPCZ_RESULT_INVALID_ARGUMENT if `driver_handle` was invalid. - IpczResult(IPCZ_API* Box)(IpczHandle node, - IpczDriverHandle driver_handle, - uint32_t flags, - const void* options, - IpczHandle* handle); + IpczResult(IPCZ_API* Box)(IpczHandle node, // in + IpczDriverHandle driver_handle, // in + uint32_t flags, // in + const void* options, // in + IpczHandle* handle); // out // Unboxes a driver object from an IpczHandle previously produced by Box(). // @@ -1439,10 +1448,10 @@ // // IPCZ_RESULT_INVALID_ARGUMENT if `handle` is invalid or does not // reference a box. - IpczResult(IPCZ_API* Unbox)(IpczHandle handle, - IpczUnboxFlags flags, - const void* options, - IpczDriverHandle* driver_handle); + IpczResult(IPCZ_API* Unbox)(IpczHandle handle, // in + IpczUnboxFlags flags, // in + const void* options, // in + IpczDriverHandle* driver_handle); // out }; // A function which populates `api` with a table of ipcz API functions. The
diff --git a/third_party/libwebp/README.chromium b/third_party/libwebp/README.chromium index 16fe95c..f01521c 100644 --- a/third_party/libwebp/README.chromium +++ b/third_party/libwebp/README.chromium
@@ -13,7 +13,7 @@ WebP is an image format that does both lossy and lossless compression of digital photographic images. WebP consists of a codec based on VP8, that Google -open-sourced in May 2010 and a container based on RIFF. Webmasters, web +open-sourced in May 2010 and a container based on RIFF. Website managers, web developers and browser developers can use WebP to compress, archive and distribute digital images more efficiently.
diff --git a/third_party/tflite_support/BUILD.gn b/third_party/tflite_support/BUILD.gn index a28c3cc..ae1ca47 100644 --- a/third_party/tflite_support/BUILD.gn +++ b/third_party/tflite_support/BUILD.gn
@@ -121,6 +121,10 @@ "src/tensorflow_lite_support/cc/utils/common_utils.h", "src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc", "src/tensorflow_lite_support/metadata/cc/metadata_extractor.h", + "src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc", + "src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h", + "src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc", + "src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h", ] deps = [ @@ -136,6 +140,7 @@ "//third_party/tflite:tflite-config-proto", "//third_party/tflite:tflite_public_headers", "//third_party/utf", + "//third_party/zlib:minizip", ] public_deps = [
diff --git a/third_party/tflite_support/DEPS b/third_party/tflite_support/DEPS index 31d7297fc..f078d8b8 100644 --- a/third_party/tflite_support/DEPS +++ b/third_party/tflite_support/DEPS
@@ -14,3 +14,7 @@ "+tensorflow_text", "+tflite", ] + +skip_child_includes = [ + "src", +] \ No newline at end of file
diff --git a/third_party/tflite_support/README.chromium b/third_party/tflite_support/README.chromium index 3d5f5dbf..cfaf220 100644 --- a/third_party/tflite_support/README.chromium +++ b/third_party/tflite_support/README.chromium
@@ -1,8 +1,8 @@ Name: TensorFlow Lite Support Short Name: tflite-support URL: https://github.com/tensorflow/tflite-support -Version: cd6760941473a976a78aa33fa199c2058f1441fd -Date: 2022/04/13 +Version: be6820a9a617b57defecbd4c766eb0bab707cac0 +Date: 2022/05/25 License: Apache 2.0 License File: LICENSE Security Critical: Yes @@ -27,16 +27,20 @@ is a no-op in chromium builds and upsets clang. 04) Do not use absl::any since it is not supported in chromium 05) Remove an unneeded static initializer. -06) Run clang-format. +06) Check (instead of resetting) the cancel_flag_ before Invoking the model. +07) Remove support for creating a model handler from a file. +08) Run clang-format. * This patch might not apply cleanly, so run `git cl format` and commit the changes. -07) Check (instead of resetting) the cancel_flag_ before Invoking the model. -08) Remove unbuilt files that cause presubmit errors. -09) Remove support for creating a model handler from a file. +09) Remove unbuilt files that cause `git cl presubmit` errors. + * This patch intentionally does not apply: it was made with + `--irreversible-delete` because it deletes a large .tflite file, which causes + the chromium-presubmit bot to fail. Update Process (internal: http://shortn/_nwz8liqimy): 1) Run these commands: ``` + pushd third_party/tflite_support/ rm -rf src/ git clone https://github.com/tensorflow/tflite-support/ @@ -47,5 +51,4 @@ 2) Apply each patch listed above residing in patches/ using `git apply third_party/tflite_support/patches/$PATCHFILE`. 3) Get the build working. -4) Record the patches made with `git format-patches HEAD -<number of changes>` - +4) Record the patches made with `git format-patch HEAD -<number of changes>`
diff --git a/third_party/tflite_support/patches/0001-use-re2-StringPiece-for-RegexTokenizer-Tokenize.patch b/third_party/tflite_support/patches/0001-use-re2-StringPiece-for-RegexTokenizer-Tokenize.patch index 2402702..4993b7a6 100644 --- a/third_party/tflite_support/patches/0001-use-re2-StringPiece-for-RegexTokenizer-Tokenize.patch +++ b/third_party/tflite_support/patches/0001-use-re2-StringPiece-for-RegexTokenizer-Tokenize.patch
@@ -1,6 +1,6 @@ -From 88b2bdd174950afc9e01bb902493e38d8ef2bf66 Mon Sep 17 00:00:00 2001 +From e19df25a06cb62b9e49b937c17d391d3b90bb3aa Mon Sep 17 00:00:00 2001 From: Robert Ogden <robertogden@chromium.org> -Date: Wed, 13 Apr 2022 10:56:01 -0700 +Date: Wed, 25 May 2022 10:52:32 -0700 Subject: [PATCH 1/9] use re2 StringPiece for RegexTokenizer Tokenize --- @@ -32,5 +32,5 @@ bool has_non_empty_token = token.length() > 0; -- -2.35.1.1178.g4f1659d476-goog +2.36.1.124.g0e6072fb45-goog
diff --git a/third_party/tflite_support/patches/0002-sentencepiece-tokenization-not-supported.patch b/third_party/tflite_support/patches/0002-sentencepiece-tokenization-not-supported.patch index 38bdf94..b14c6c7 100644 --- a/third_party/tflite_support/patches/0002-sentencepiece-tokenization-not-supported.patch +++ b/third_party/tflite_support/patches/0002-sentencepiece-tokenization-not-supported.patch
@@ -1,6 +1,6 @@ -From 9d4affd253d393f862c1d084207a930b692afd93 Mon Sep 17 00:00:00 2001 +From bd41a985345f17306e472f7d825d30b3e3d0baba Mon Sep 17 00:00:00 2001 From: Robert Ogden <robertogden@chromium.org> -Date: Wed, 13 Apr 2022 10:56:24 -0700 +Date: Wed, 25 May 2022 10:52:49 -0700 Subject: [PATCH 2/9] sentencepiece tokenization not supported --- @@ -47,5 +47,5 @@ case ProcessUnitOptions_RegexTokenizerOptions: { const tflite::RegexTokenizerOptions* options = -- -2.35.1.1178.g4f1659d476-goog +2.36.1.124.g0e6072fb45-goog
diff --git a/third_party/tflite_support/patches/0003-rm-noop-deprecated-attribute.patch b/third_party/tflite_support/patches/0003-rm-noop-deprecated-attribute.patch index f1ef6c97..fbc6659 100644 --- a/third_party/tflite_support/patches/0003-rm-noop-deprecated-attribute.patch +++ b/third_party/tflite_support/patches/0003-rm-noop-deprecated-attribute.patch
@@ -1,6 +1,6 @@ -From 758683d39012410890eb4599e1df5d315e76361e Mon Sep 17 00:00:00 2001 +From 9e135eeb8ce9afb6919588bfa0cad0e8b9c77c92 Mon Sep 17 00:00:00 2001 From: Robert Ogden <robertogden@chromium.org> -Date: Wed, 13 Apr 2022 10:56:42 -0700 +Date: Wed, 25 May 2022 10:53:14 -0700 Subject: [PATCH 3/9] rm noop deprecated attribute --- @@ -22,5 +22,5 @@ int input_tensor_index = 0; int output_score_tensor_index = 0; -- -2.35.1.1178.g4f1659d476-goog +2.36.1.124.g0e6072fb45-goog
diff --git a/third_party/tflite_support/patches/0004-do-not-use-absl-any.patch b/third_party/tflite_support/patches/0004-do-not-use-absl-any.patch index 318a41c6..577d58d 100644 --- a/third_party/tflite_support/patches/0004-do-not-use-absl-any.patch +++ b/third_party/tflite_support/patches/0004-do-not-use-absl-any.patch
@@ -1,6 +1,6 @@ -From 84e07f913af87bc2fc6d280a84f0310d7e15ba78 Mon Sep 17 00:00:00 2001 +From c7a885174c0489e6a6a819b4bae23311226e43ee Mon Sep 17 00:00:00 2001 From: Robert Ogden <robertogden@chromium.org> -Date: Wed, 13 Apr 2022 10:57:08 -0700 +Date: Wed, 25 May 2022 10:53:34 -0700 Subject: [PATCH 4/9] do not use absl any --- @@ -60,5 +60,5 @@ Format format_; Orientation orientation_; -- -2.35.1.1178.g4f1659d476-goog +2.36.1.124.g0e6072fb45-goog
diff --git a/third_party/tflite_support/patches/0005-rm-stdio-static-init.patch b/third_party/tflite_support/patches/0005-rm-stdio-static-init.patch index 77b1d64..ef83104 100644 --- a/third_party/tflite_support/patches/0005-rm-stdio-static-init.patch +++ b/third_party/tflite_support/patches/0005-rm-stdio-static-init.patch
@@ -1,6 +1,6 @@ -From 634325b170300378eafdbbdb54f32e8d7b8b45ab Mon Sep 17 00:00:00 2001 +From fc236446d266f3fe7002bb2aa5d4b50e5241dcab Mon Sep 17 00:00:00 2001 From: Robert Ogden <robertogden@chromium.org> -Date: Wed, 13 Apr 2022 10:59:05 -0700 +Date: Wed, 25 May 2022 10:53:56 -0700 Subject: [PATCH 5/9] rm stdio static init --- @@ -34,5 +34,5 @@ using ::tflite::proto::ComputeSettings; using ::tflite::support::CreateStatusWithPayload; -- -2.35.1.1178.g4f1659d476-goog +2.36.1.124.g0e6072fb45-goog
diff --git a/third_party/tflite_support/patches/0007-check-cancel-flag-before-calling-invoke.patch b/third_party/tflite_support/patches/0006-check-cancel-flag-before-calling-invoke.patch similarity index 91% rename from third_party/tflite_support/patches/0007-check-cancel-flag-before-calling-invoke.patch rename to third_party/tflite_support/patches/0006-check-cancel-flag-before-calling-invoke.patch index e54588c..e20d5a0 100644 --- a/third_party/tflite_support/patches/0007-check-cancel-flag-before-calling-invoke.patch +++ b/third_party/tflite_support/patches/0006-check-cancel-flag-before-calling-invoke.patch
@@ -1,14 +1,14 @@ -From e45f2fd967865982119a15ce4c8f4f220cc1538f Mon Sep 17 00:00:00 2001 +From 98f819d7d88b6f03b3bbab2d116d2fa31674a154 Mon Sep 17 00:00:00 2001 From: Robert Ogden <robertogden@chromium.org> -Date: Wed, 13 Apr 2022 11:09:45 -0700 -Subject: [PATCH 7/9] check cancel flag before calling invoke +Date: Wed, 25 May 2022 10:54:30 -0700 +Subject: [PATCH 6/9] check cancel flag before calling invoke --- .../cc/port/default/tflite_wrapper.cc | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc -index bb43d09f4a96b..4d23efe43bc99 100644 +index d47c1ce7e5179..11f9d584cfdd0 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc @@ -258,8 +258,10 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithFallback( @@ -53,5 +53,5 @@ } return absl::InternalError("Invoke() failed."); -- -2.35.1.1178.g4f1659d476-goog +2.36.1.124.g0e6072fb45-goog
diff --git a/third_party/tflite_support/patches/0009-remove-support-for-creating-a-model-handler-from-a-f.patch b/third_party/tflite_support/patches/0007-remove-support-for-creating-a-model-handler-from-a-f.patch similarity index 87% rename from third_party/tflite_support/patches/0009-remove-support-for-creating-a-model-handler-from-a-f.patch rename to third_party/tflite_support/patches/0007-remove-support-for-creating-a-model-handler-from-a-f.patch index e5824d8..529debc6 100644 --- a/third_party/tflite_support/patches/0009-remove-support-for-creating-a-model-handler-from-a-f.patch +++ b/third_party/tflite_support/patches/0007-remove-support-for-creating-a-model-handler-from-a-f.patch
@@ -1,17 +1,30 @@ -From b83a7480d906936e2920a17027d1a76bb4445673 Mon Sep 17 00:00:00 2001 +From 9488411c0779c22ba93fb03994e90cda25b65bd0 Mon Sep 17 00:00:00 2001 From: Robert Ogden <robertogden@chromium.org> -Date: Wed, 13 Apr 2022 11:51:51 -0700 -Subject: [PATCH 9/9] remove support for creating a model handler from a file +Date: Wed, 25 May 2022 11:03:30 -0700 +Subject: [PATCH 7/9] remove support for creating a model handler from a file --- - .../cc/task/core/external_file_handler.cc | 136 +----------------- - .../cc/task/core/external_file_handler.h | 21 --- + third_party/tflite_support/README.chromium | 1 + + .../cc/task/core/external_file_handler.cc | 137 +----------------- + .../cc/task/core/external_file_handler.h | 20 --- .../cc/task/core/tflite_engine.cc | 2 - .../cc/task/core/tflite_engine.h | 2 - - 4 files changed, 6 insertions(+), 155 deletions(-) + 5 files changed, 8 insertions(+), 154 deletions(-) +diff --git a/third_party/tflite_support/README.chromium b/third_party/tflite_support/README.chromium +index 3d5f5dbf2edf9..c1da2b4e73efe 100644 +--- a/third_party/tflite_support/README.chromium ++++ b/third_party/tflite_support/README.chromium +@@ -37,6 +37,7 @@ is a no-op in chromium builds and upsets clang. + Update Process (internal: http://shortn/_nwz8liqimy): + 1) Run these commands: + ``` ++ + pushd third_party/tflite_support/ + rm -rf src/ + git clone https://github.com/tensorflow/tflite-support/ diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc -index b9ae32253cb29..e15830d5ab061 100644 +index 5e17e14dc5f7a..9c4cc2009baea 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc @@ -15,12 +15,6 @@ limitations under the License. 
@@ -46,7 +59,7 @@ } // namespace /* static */ -@@ -71,123 +53,17 @@ absl::Status ExternalFileHandler::MapExternalFile() { +@@ -71,123 +53,18 @@ absl::Status ExternalFileHandler::MapExternalFile() { if (!external_file_.file_content().empty()) { return absl::OkStatus(); } @@ -151,6 +164,7 @@ + StatusCode::kInvalidArgument, + "ExternalFile must specify 'file_content' in Chromium.", + TfLiteSupportStatus::kInvalidArgumentError); ++ } absl::string_view ExternalFileHandler::GetFileContent() { @@ -161,7 +175,7 @@ - buffer_offset_ - buffer_aligned_offset_, - buffer_size_); - } -+ return external_file_.file_content(); ++ return external_file_.file_content(); } -ExternalFileHandler::~ExternalFileHandler() { @@ -177,14 +191,13 @@ } // namespace core } // namespace task diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h -index 0b74e468d004f..9f35fdd6d09ce 100644 +index e8b6831c6ad69..a7daa175f77f5 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h -@@ -64,27 +64,6 @@ class ExternalFileHandler { - +@@ -65,26 +65,6 @@ class ExternalFileHandler { // Reference to the input ExternalFile. const ExternalFile& external_file_; -- + - // The file descriptor of the ExternalFile if provided by path, as it is - // opened and owned by this class. Set to -1 otherwise. 
- int owned_fd_{-1}; @@ -209,7 +222,7 @@ } // namespace core diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc -index 2794290a2411e..41e06389af80b 100644 +index e0f69cd1c80ac..5999090cab973 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc @@ -15,8 +15,6 @@ limitations under the License. @@ -220,9 +233,9 @@ - #include <memory> - #include "absl/strings/match.h" // from @com_google_absl + #include "absl/strings/match.h" // from @com_google_absl diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h -index 1c6a067d6be9e..0cbaa738e6db6 100644 +index 9b44c6e5c022a..53dabdc4841d7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h @@ -16,8 +16,6 @@ limitations under the License. @@ -233,7 +246,7 @@ - #include <memory> - #include "absl/memory/memory.h" // from @com_google_absl + #include "absl/memory/memory.h" // from @com_google_absl -- -2.35.1.1178.g4f1659d476-goog +2.36.1.124.g0e6072fb45-goog
diff --git a/third_party/tflite_support/patches/0008-remove-unbuilt-files-with-presubmit-errors.patch b/third_party/tflite_support/patches/0008-remove-unbuilt-files-with-presubmit-errors.patch deleted file mode 100644 index a398cd6..0000000 --- a/third_party/tflite_support/patches/0008-remove-unbuilt-files-with-presubmit-errors.patch +++ /dev/null
@@ -1,4795 +0,0 @@ -From b545fb0077fc975480b8504b805464a41c6ae30f Mon Sep 17 00:00:00 2001 -From: Robert Ogden <robertogden@chromium.org> -Date: Wed, 13 Apr 2022 11:15:48 -0700 -Subject: [PATCH 8/9] remove unbuilt files with presubmit errors - ---- - .../cc/port/benchmark.h | 21 - - .../cc/port/default/status_matchers.h | 55 -- - .../tensorflow_lite_support/cc/port/gmock.h | 21 - - .../tensorflow_lite_support/cc/port/gtest.h | 21 - - .../tensorflow_lite_support/cc/port/proto2.h | 32 -- - .../cc/task/processor/search_postprocessor.cc | 362 ------------ - .../cc/task/processor/search_postprocessor.h | 112 ---- - .../ios/utils/Sources/TFLStringUtil.mm | 26 - - .../metadata/cc/metadata_populator.cc | 150 ----- - .../metadata/cc/utils/zip_mem_file.cc | 134 ----- - .../metadata/cc/utils/zip_mem_file.h | 75 --- - .../audio/pybinds/_pywrap_audio_classifier.cc | 84 --- - .../audio/pybinds/_pywrap_audio_embedder.cc | 78 --- - .../text/pybinds/_pywrap_text_embedder.cc | 68 --- - .../task/vision/core/pybinds/image_utils.cc | 68 --- - .../pybinds/_pywrap_image_classifier.cc | 108 ---- - .../vision/pybinds/_pywrap_image_embedder.cc | 145 ----- - .../vision/pybinds/_pywrap_image_segmenter.cc | 73 --- - .../vision/pybinds/_pywrap_object_detector.cc | 84 --- - .../scann_ondevice/cc/core/index_table_sum.h | 256 --------- - .../scann_ondevice/cc/core/partitioner.h | 76 --- - .../scann_ondevice/cc/core/processor.h | 101 ---- - .../scann_ondevice/cc/core/searcher.h | 256 --------- - .../scann_ondevice/cc/core/searcher_test.cc | 532 ------------------ - .../scann_ondevice/cc/core/simd_utils.h | 303 ---------- - .../scann_ondevice/cc/index.cc | 138 ----- - .../scann_ondevice/cc/index.h | 91 --- - .../scann_ondevice/cc/index_builder.cc | 177 ------ - .../scann_ondevice/cc/index_builder.h | 68 --- - .../cc/mem_random_access_file.cc | 52 -- - .../cc/mem_random_access_file.h | 61 -- - .../scann_ondevice/cc/mem_writable_file.h | 64 --- - .../cc/python/index_builder_py_wrapper.cc | 64 --- - 
.../cc/test/index_builder_test.cc | 363 ------------ - .../cc/test/mem_random_access_file_test.cc | 64 --- - .../leveldb_testing_utils_py_wrapper.cc | 76 --- - .../tools/ci_build/common_win.bat | 29 - - 37 files changed, 4488 deletions(-) - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/_pywrap_text_embedder.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc - delete mode 100644 
third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_embedder.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/processor.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/simd_utils.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h - delete mode 100644 
third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_random_access_file_test.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc - delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat - -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h -deleted file mode 100644 -index 74bc1a6857664..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h -+++ /dev/null -@@ -1,21 +0,0 @@ --/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_ --#define TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_ -- --#include "gtest/benchmark.h" -- --#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h -deleted file mode 100644 -index 6d9668043c183..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h -+++ /dev/null -@@ -1,55 +0,0 @@ --/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_ --#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_ -- --#include "gmock/gmock.h" --#include "gtest/gtest.h" -- --#define SUPPORT_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y --#define SUPPORT_STATUS_MACROS_IMPL_CONCAT_(x, y) \ -- SUPPORT_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) -- --#undef SUPPORT_ASSERT_OK --#define SUPPORT_ASSERT_OK(expr) \ -- SUPPORT_ASSERT_OK_IMPL_( \ -- SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status, __LINE__), expr) -- --#define SUPPORT_ASSERT_OK_IMPL_(status, expr) \ -- auto status = (expr); \ -- ASSERT_TRUE(status.ok()); -- --#undef SUPPORT_EXPECT_OK --#define SUPPORT_EXPECT_OK(expr) \ -- SUPPORT_EXPECT_OK_IMPL_( \ -- SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status, __LINE__), expr) -- --#define SUPPORT_EXPECT_OK_IMPL_(status, expr) \ -- auto status = (expr); \ -- EXPECT_TRUE(status.ok()); -- --#undef SUPPORT_ASSERT_OK_AND_ASSIGN --#define SUPPORT_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ -- SUPPORT_ASSERT_OK_AND_ASSIGN_IMPL_( \ -- SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, \ -- rexpr) -- --#define SUPPORT_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, rexpr) \ -- auto statusor = (rexpr); \ -- ASSERT_TRUE(statusor.ok()); \ -- lhs = std::move(statusor.value()) -- --#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h -deleted file mode 100644 -index 5e4334db323d6..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h -+++ /dev/null -@@ -1,21 +0,0 @@ --/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. 
--You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_ --#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_ -- --#include "gmock/gmock.h" -- --#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h -deleted file mode 100644 -index dbe2e5e6f9d7c..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h -+++ /dev/null -@@ -1,21 +0,0 @@ --/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_ --#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_ -- --#include "gtest/gtest.h" -- --#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h -deleted file mode 100644 -index 3cde2ab81d6ee..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h -+++ /dev/null -@@ -1,32 +0,0 @@ --/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ --#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO_NS_H_ --#define TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO_NS_H_ -- --#include "google/protobuf/message_lite.h" --#include "google/protobuf/text_format.h" -- --namespace tflite { --namespace support { --namespace proto { -- --using TextFormat = ::google::protobuf::TextFormat; --using MessageLite = ::google::protobuf::MessageLite; -- --} // namespace proto --} // namespace support --} // namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO_NS_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc -deleted file mode 100644 -index e3bc2688caf3a..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc -+++ /dev/null -@@ -1,362 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#include "tensorflow_lite_support/cc/task/processor/search_postprocessor.h" -- --#include <algorithm> --#include <cstdint> --#include <initializer_list> --#include <limits> --#include <memory> --#include <vector> -- --#include "Eigen/Core" // from @eigen --#include "absl/memory/memory.h" // from @com_google_absl --#include "absl/status/status.h" // from @com_google_absl --#include "absl/strings/str_format.h" // from @com_google_absl --#include "absl/types/span.h" // from @com_google_absl --#include "tensorflow_lite_support/cc/common.h" --#include "tensorflow_lite_support/cc/port/status_macros.h" --#include "tensorflow_lite_support/cc/port/statusor.h" --#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" --#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" --#include "tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h" --#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h" --#include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h" --#include "tensorflow_lite_support/cc/task/processor/proto/search_options.pb.h" --#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h" --#include "tensorflow_lite_support/scann_ondevice/cc/index.h" --#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h" -- --namespace tflite { --namespace task { --namespace processor { -- --namespace { -- --constexpr int kNoNeighborId = -1; -- --using ::tflite::scann_ondevice::Index; --using 
::tflite::scann_ondevice::IndexConfig; --using ::tflite::scann_ondevice::core::AsymmetricHashFindNeighbors; --using ::tflite::scann_ondevice::core::DistanceMeasure; --using ::tflite::scann_ondevice::core::FloatFindNeighbors; --using ::tflite::scann_ondevice::core::QueryInfo; --using ::tflite::scann_ondevice::core::ScannOnDeviceConfig; --using ::tflite::scann_ondevice::core::TopN; --using ::tflite::support::CreateStatusWithPayload; --using ::tflite::support::StatusOr; --using ::tflite::support::TfLiteSupportStatus; --using ::tflite::task::core::ExternalFileHandler; --using ::tflite::task::core::TfLiteEngine; --using ::tflite::task::processor::Embedding; -- --using Matrix8u = -- Eigen::Matrix<uint8_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>; -- --absl::StatusOr<std::unique_ptr<EmbeddingPostprocessor>> --CreateEmbeddingPostprocessor(TfLiteEngine* engine, -- const std::initializer_list<int> output_indices, -- std::unique_ptr<EmbeddingOptions> options) { -- if (options->quantize()) { -- // ScaNN only supports searching from float embeddings. 
-- return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, -- "Setting EmbeddingOptions.normalize = true " -- "is not allowed in searchers.", -- TfLiteSupportStatus::kInvalidArgumentError); -- } -- return EmbeddingPostprocessor::Create(engine, output_indices, -- std::move(options)); --} -- --absl::Status SanityCheckOptions(const SearchOptions& options) { -- if (options.num_results() < 1) { -- return CreateStatusWithPayload( -- absl::StatusCode::kInvalidArgument, -- absl::StrFormat("SearchOptions.num_results must be > 0, found %d.", -- options.num_results()), -- TfLiteSupportStatus::kInvalidArgumentError); -- } -- return absl::OkStatus(); --} -- --absl::Status SanityCheckIndexConfig(const IndexConfig& config) { -- switch (config.embedding_type()) { -- case IndexConfig::UNSPECIFIED: -- return CreateStatusWithPayload( -- absl::StatusCode::kInvalidArgument, -- "Invalid IndexConfig: embedding_type must not be left UNSPECIFIED.", -- TfLiteSupportStatus::kInvalidArgumentError); -- case IndexConfig::FLOAT: -- if (config.scann_config().has_indexer()) { -- return CreateStatusWithPayload( -- absl::StatusCode::kInvalidArgument, -- "Invalid IndexConfig: embedding_type is set to FLOAT but ScaNN " -- "config specifies a product quantization codebook.", -- TfLiteSupportStatus::kInvalidArgumentError); -- } -- break; -- case IndexConfig::UINT8: -- if (!config.scann_config().has_indexer()) { -- return CreateStatusWithPayload( -- absl::StatusCode::kInvalidArgument, -- "Invalid IndexConfig: embedding_type is set to UINT8 but ScaNN " -- "config doesn't specify a product quantization codebook.", -- TfLiteSupportStatus::kInvalidArgumentError); -- } -- break; -- default: -- return CreateStatusWithPayload( -- absl::StatusCode::kInternal, -- "Invalid IndexConfig: unexpected value for embedding_type.", -- TfLiteSupportStatus::kError); -- } -- return absl::OkStatus(); --} -- --absl::StatusOr<DistanceMeasure> GetDistanceMeasure( -- const ScannOnDeviceConfig& config) { -- 
DistanceMeasure measure = config.query_distance(); -- if (measure == tflite::scann_ondevice::core::UNSPECIFIED) { -- if (config.has_indexer() && config.indexer().has_asymmetric_hashing()) { -- measure = config.indexer().asymmetric_hashing().query_distance(); -- } else if (config.has_partitioner()) { -- measure = config.partitioner().query_distance(); -- } else { -- return CreateStatusWithPayload( -- absl::StatusCode::kInvalidArgument, -- "ScaNN config does not provide mandatory DistanceMeasure.", -- TfLiteSupportStatus::kInvalidArgumentError); -- } -- -- if (measure == tflite::scann_ondevice::core::UNSPECIFIED) { -- return CreateStatusWithPayload( -- absl::StatusCode::kInvalidArgument, -- "UNSPECIFIED is not a valid value for ScaNN config DistanceMeasure.", -- TfLiteSupportStatus::kInvalidArgumentError); -- } -- -- // Make sure the query distance in different places are consistent. -- if (config.has_partitioner()) { -- DistanceMeasure partitioner_measure = -- config.partitioner().query_distance(); -- if (measure != partitioner_measure) { -- return CreateStatusWithPayload( -- absl::StatusCode::kInvalidArgument, -- absl::StrFormat("DistanceMeasure %s is different from " -- "DistanceMeasure %s found in partitioner config.", -- DistanceMeasure_Name(measure), -- DistanceMeasure_Name(partitioner_measure)), -- TfLiteSupportStatus::kInvalidArgumentError); -- } -- } -- } -- return measure; --} -- --absl::Status ConvertEmbeddingToEigenMatrix(const Embedding& embedding, -- Eigen::MatrixXf* matrix) { -- if (embedding.feature_vector().value_float().empty()) { -- // This should be caught upstream at EmbeddingPostprocessor creation. 
-- return CreateStatusWithPayload(absl::StatusCode::kInternal, -- "Float query embedding is empty.", -- TfLiteSupportStatus::kError); -- } -- Eigen::Map<const Eigen::VectorXf> query_ptr( -- embedding.feature_vector().value_float().data(), -- embedding.feature_vector().value_float().size()); -- matrix->resize(embedding.feature_vector().value_float().size(), 1); -- matrix->col(0) = query_ptr; -- return absl::OkStatus(); --} -- --} // namespace -- --/* static */ --StatusOr<std::unique_ptr<SearchPostprocessor>> SearchPostprocessor::Create( -- TfLiteEngine* engine, -- int output_index, -- std::unique_ptr<SearchOptions> search_options, -- std::unique_ptr<EmbeddingOptions> embedding_options) { -- ASSIGN_OR_RETURN(auto embedding_postprocessor, -- CreateEmbeddingPostprocessor(engine, {output_index}, -- std::move(embedding_options))); -- -- ASSIGN_OR_RETURN(auto search_processor, -- Processor::Create<SearchPostprocessor>( -- /* num_expected_tensors =*/1, engine, {output_index}, -- /* requires_metadata =*/false)); -- -- RETURN_IF_ERROR(search_processor->Init(std::move(embedding_postprocessor), -- std::move(search_options))); -- return search_processor; --} -- --StatusOr<SearchResult> SearchPostprocessor::Postprocess() { -- // Extract embedding. -- Embedding embedding; -- RETURN_IF_ERROR(embedding_postprocessor_->Postprocess(&embedding)); -- // Convert embedding to Eigen matrix, as expected by ScaNN. -- Eigen::MatrixXf query; -- RETURN_IF_ERROR(ConvertEmbeddingToEigenMatrix(embedding, &query)); -- -- // Identify partitions to search. -- std::vector<std::vector<int>> leaves_to_search( -- 1, std::vector<int>(num_leaves_to_search_, -1)); -- if (!partitioner_->Partition(query, &leaves_to_search)) { -- return CreateStatusWithPayload(absl::StatusCode::kInternal, -- "Partitioning failed.", -- TfLiteSupportStatus::kError); -- } -- -- // Prepare search results. 
-- std::vector<TopN> top_n; -- top_n.emplace_back( -- options_->num_results(), -- std::make_pair(std::numeric_limits<float>::max(), kNoNeighborId)); -- // Perform search. -- if (quantizer_) { -- RETURN_IF_ERROR( -- QuantizedSearch(query, leaves_to_search[0], absl::MakeSpan(top_n))); -- } else { -- RETURN_IF_ERROR( -- LinearSearch(query, leaves_to_search[0], absl::MakeSpan(top_n))); -- } -- -- // Build results. -- SearchResult search_result; -- for (const auto& [distance, id] : top_n[0].Take()) { -- if (id == kNoNeighborId) { -- break; -- } -- ASSIGN_OR_RETURN(auto metadata, index_->GetMetadataAtIndex(id)); -- NearestNeighbor* nearest_neighbor = search_result.add_nearest_neighbors(); -- nearest_neighbor->set_distance(distance); -- nearest_neighbor->set_metadata(std::string(metadata)); -- } -- return search_result; --} -- --StatusOr<absl::string_view> SearchPostprocessor::GetUserInfo() { -- return index_->GetUserInfo(); --} -- --absl::Status SearchPostprocessor::Init( -- std::unique_ptr<EmbeddingPostprocessor> embedding_postprocessor, -- std::unique_ptr<SearchOptions> options) { -- embedding_postprocessor_ = std::move(embedding_postprocessor); -- RETURN_IF_ERROR(SanityCheckOptions(*options)); -- options_ = std::move(options); -- -- // Initialize index. -- ASSIGN_OR_RETURN( -- index_file_handler_, -- ExternalFileHandler::CreateFromExternalFile(&options_->index_file())); -- auto index_file_content = index_file_handler_->GetFileContent(); -- ASSIGN_OR_RETURN(index_, -- Index::CreateFromIndexBuffer(index_file_content.data(), -- index_file_content.size())); -- ASSIGN_OR_RETURN(index_config_, index_->GetIndexConfig()); -- RETURN_IF_ERROR(SanityCheckIndexConfig(index_config_)); -- // Get distance measure once and for all. -- ASSIGN_OR_RETURN(distance_measure_, -- GetDistanceMeasure(index_config_.scann_config())); -- -- // Initialize partitioner. 
-- if (index_config_.scann_config().has_partitioner()) { -- partitioner_ = tflite::scann_ondevice::core::Partitioner::Create( -- index_config_.scann_config().partitioner()); -- num_leaves_to_search_ = std::min( -- static_cast<int>(ceilf( -- partitioner_->NumPartitions() * -- index_config_.scann_config().partitioner().search_fraction())), -- partitioner_->NumPartitions()); -- } else { -- partitioner_ = -- absl::make_unique<tflite::scann_ondevice::core::NoOpPartitioner>(); -- num_leaves_to_search_ = partitioner_->NumPartitions(); -- } -- -- // Initialize product quantizer if needed. -- if (index_config_.scann_config().has_indexer()) { -- quantizer_ = tflite::scann_ondevice::core::AsymmetricHashQuerier::Create( -- index_config_.scann_config().indexer().asymmetric_hashing()); -- } -- -- return absl::OkStatus(); --} -- --absl::Status SearchPostprocessor::QuantizedSearch( -- Eigen::Ref<Eigen::MatrixXf> query, -- std::vector<int> leaves_to_search, -- absl::Span<TopN> top_n) { -- int dim = index_config_.embedding_dim(); -- // Prepare QueryInfo used for all leaves. -- QueryInfo query_info; -- if (!quantizer_->Process(query, &query_info)) { -- return CreateStatusWithPayload(absl::StatusCode::kInternal, -- "Query quantization failed.", -- TfLiteSupportStatus::kError); -- } -- for (int leaf_id : leaves_to_search) { -- // Load partition into Eigen matrix. -- ASSIGN_OR_RETURN(auto partition, index_->GetPartitionAtIndex(leaf_id)); -- int partition_size = partition.size() / dim; -- Eigen::Map<const Matrix8u> database( -- reinterpret_cast<const uint8_t*>(partition.data()), dim, -- partition_size); -- // Perform search. 
-- int global_offset = index_config_.global_partition_offsets(leaf_id); -- if (!AsymmetricHashFindNeighbors(query_info, database, global_offset, -- top_n)) { -- return CreateStatusWithPayload(absl::StatusCode::kInternal, -- "Nearest neighbor search failed.", -- TfLiteSupportStatus::kError); -- } -- } -- return absl::OkStatus(); --} -- --absl::Status SearchPostprocessor::LinearSearch( -- Eigen::Ref<Eigen::MatrixXf> query, -- std::vector<int> leaves_to_search, -- absl::Span<TopN> top_n) { -- int dim = index_config_.embedding_dim(); -- for (int leaf_id : leaves_to_search) { -- // Load partition into Eigen matrix. -- ASSIGN_OR_RETURN(auto partition, index_->GetPartitionAtIndex(leaf_id)); -- int partition_size = partition.size() / (dim * sizeof(float)); -- Eigen::Map<const Eigen::MatrixXf> database( -- reinterpret_cast<const float*>(partition.data()), dim, partition_size); -- // Perform search. -- int global_offset = index_config_.global_partition_offsets(leaf_id); -- if (!FloatFindNeighbors(query, database, global_offset, distance_measure_, -- top_n)) { -- return CreateStatusWithPayload(absl::StatusCode::kInternal, -- "Nearest neighbor search failed.", -- TfLiteSupportStatus::kError); -- } -- } -- return absl::OkStatus(); --} -- --} // namespace processor --} // namespace task --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h -deleted file mode 100644 -index d79bc853148a9..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h -+++ /dev/null -@@ -1,112 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. 
--You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_SEARCH_POSTPROCESSOR_H_ --#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_SEARCH_POSTPROCESSOR_H_ -- --#include <cstdint> --#include <initializer_list> --#include <memory> --#include <vector> -- --#include "Eigen/Core" // from @eigen --#include "absl/strings/string_view.h" // from @com_google_absl --#include "absl/types/span.h" // from @com_google_absl --#include "tensorflow_lite_support/cc/port/statusor.h" --#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" --#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" --#include "tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h" --#include "tensorflow_lite_support/cc/task/processor/processor.h" --#include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h" --#include "tensorflow_lite_support/cc/task/processor/proto/search_options.pb.h" --#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h" --#include "tensorflow_lite_support/scann_ondevice/cc/index.h" --#include 
"tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h" -- --namespace tflite { --namespace task { --namespace processor { -- --// Postprocessor in charge of performing embedding extraction followed by --// nearest-neighbor search. --// --// This postprocessor works with the following output tensor: --// (kTfLiteUInt8/kTfLiteFloat32) --// - `N` components corresponding to the `N` dimensions of the returned --// feature vector for this output layer. --// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. --class SearchPostprocessor : public Postprocessor { -- public: -- static tflite::support::StatusOr<std::unique_ptr<SearchPostprocessor>> Create( -- tflite::task::core::TfLiteEngine* engine, -- int output_index, -- std::unique_ptr<SearchOptions> search_options, -- std::unique_ptr<EmbeddingOptions> embedding_options = -- std::make_unique<EmbeddingOptions>()); -- -- // Converts the tensor outputs to embeddings, then performs a nearest-neighbor -- // search in the index. -- tflite::support::StatusOr<SearchResult> Postprocess(); -- -- // Provides access to the opaque user info stored in the index file (if any), -- // in raw binary form. Returns an empty string if the index doesn't contain -- // user info. -- tflite::support::StatusOr<absl::string_view> GetUserInfo(); -- -- private: -- using Postprocessor::Postprocessor; -- -- absl::Status Init( -- std::unique_ptr<EmbeddingPostprocessor> embedding_postprocessor, -- std::unique_ptr<SearchOptions> options); -- -- absl::Status QuantizedSearch( -- Eigen::Ref<Eigen::MatrixXf> query, -- std::vector<int> leaves_to_search, -- absl::Span<tflite::scann_ondevice::core::TopN> top_n); -- absl::Status LinearSearch( -- Eigen::Ref<Eigen::MatrixXf> query, -- std::vector<int> leaves_to_search, -- absl::Span<tflite::scann_ondevice::core::TopN> top_n); -- -- std::unique_ptr<SearchOptions> options_; -- -- // Encapsulated EmbeddingPostprocessor converting raw tensors to embeddings. 
-- std::unique_ptr<EmbeddingPostprocessor> embedding_postprocessor_; -- -- // Index management. -- std::unique_ptr<tflite::task::core::ExternalFileHandler> index_file_handler_; -- std::unique_ptr<tflite::scann_ondevice::Index> index_; -- tflite::scann_ondevice::IndexConfig index_config_; -- -- // ScaNN management. -- int num_leaves_to_search_; -- tflite::scann_ondevice::core::DistanceMeasure distance_measure_; -- std::unique_ptr<tflite::scann_ondevice::core::PartitionerInterface> -- partitioner_; -- std::shared_ptr<tflite::scann_ondevice::core::AsymmetricHashQuerier> -- quantizer_; --}; -- --} // namespace processor --} // namespace task --} // namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_SEARCH_POSTPROCESSOR_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm -deleted file mode 100644 -index 2a11bb6730474..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm -+++ /dev/null -@@ -1,26 +0,0 @@ --/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ --#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h" -- --std::string MakeString(NSString* str) { -- return std::string([str UTF8String]); --} -- --NSString* MakeNSString(const std::string& str) { -- return [[NSString alloc] -- initWithBytes:const_cast<void*>(static_cast<const void*>(str.data())) -- length:str.length() -- encoding:NSUTF8StringEncoding]; --} -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc -deleted file mode 100644 -index 2841c730adfd1..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc -+++ /dev/null -@@ -1,150 +0,0 @@ --/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#include "tensorflow_lite_support/metadata/cc/metadata_populator.h" -- --#include <cstdlib> --#include <cstring> --#include <functional> -- --#include "contrib/minizip/ioapi.h" --#include "contrib/minizip/zip.h" --#include "flatbuffers/flatbuffers.h" // from @flatbuffers --#include "tensorflow/lite/schema/schema_generated.h" --#include "tensorflow_lite_support/cc/common.h" --#include "tensorflow_lite_support/cc/port/status_macros.h" --#include "tensorflow_lite_support/cc/port/statusor.h" --#include "tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h" --#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" -- --namespace tflite { --namespace metadata { -- --namespace { --constexpr char kMetadataBufferName[] = "TFLITE_METADATA"; -- --using ::absl::StatusCode; --using ::tflite::support::CreateStatusWithPayload; --using ::tflite::support::TfLiteSupportStatus; -- --} // namespace -- --ModelMetadataPopulator::ModelMetadataPopulator(const tflite::Model& model) { -- model.UnPackTo(&model_t_); --} -- --/* static */ --tflite::support::StatusOr<std::unique_ptr<ModelMetadataPopulator>> --ModelMetadataPopulator::CreateFromModelBuffer(const char* buffer_data, -- size_t buffer_size) { -- // Rely on the simplest, base flatbuffers verifier. Here is not the place to -- // e.g. use an OpResolver: we just want to make sure the buffer is valid to -- // access the metadata. -- flatbuffers::Verifier verifier = flatbuffers::Verifier( -- reinterpret_cast<const uint8_t*>(buffer_data), buffer_size); -- if (!tflite::VerifyModelBuffer(verifier)) { -- return CreateStatusWithPayload( -- StatusCode::kInvalidArgument, -- "The model is not a valid FlatBuffer buffer.", -- TfLiteSupportStatus::kInvalidFlatBufferError); -- } -- // Use absl::WrapUnique() to call private constructor: -- // https://abseil.io/tips/126. 
-- return absl::WrapUnique( -- new ModelMetadataPopulator(*tflite::GetModel(buffer_data))); --} -- --void ModelMetadataPopulator::LoadMetadata(const char* metadata_buffer_data, -- size_t metadata_buffer_size) { -- // Pack the model metadata in a buffer. -- auto model_metadata_buffer = std::make_unique<tflite::BufferT>(); -- model_metadata_buffer->data = {metadata_buffer_data, -- metadata_buffer_data + metadata_buffer_size}; -- // Check if the model already has metadata. If so, just override the buffer -- // and exit. -- for (const auto& metadata_t : model_t_.metadata) { -- if (metadata_t->name == kMetadataBufferName) { -- model_t_.buffers[metadata_t->buffer] = std::move(model_metadata_buffer); -- return; -- } -- } -- // Model doesn't already have metadata: add metadata buffer and pointer to the -- // buffer in the model metadata section. -- model_t_.buffers.push_back(std::move(model_metadata_buffer)); -- auto metadata_t = std::make_unique<tflite::MetadataT>(); -- metadata_t->name = kMetadataBufferName; -- metadata_t->buffer = model_t_.buffers.size() - 1; -- model_t_.metadata.push_back(std::move(metadata_t)); --} -- --void ModelMetadataPopulator::LoadAssociatedFiles( -- const absl::flat_hash_map<std::string, std::string>& associated_files) { -- associated_files_ = associated_files; --} -- --tflite::support::StatusOr<std::string> --ModelMetadataPopulator::AppendAssociatedFiles(const char* model_buffer_data, -- size_t model_buffer_size) { -- // Create in-memory zip file. -- ZipMemFile mem_file = ZipMemFile(model_buffer_data, model_buffer_size); -- // Open zip. -- zipFile zf = zipOpen2(/*pathname=*/nullptr, APPEND_STATUS_CREATEAFTER, -- /*globalcomment=*/nullptr, &mem_file.GetFileFuncDef()); -- if (zf == nullptr) { -- return CreateStatusWithPayload( -- StatusCode::kUnknown, "Unable to open zip archive", -- TfLiteSupportStatus::kMetadataAssociatedFileZipError); -- } -- // Write associated files. 
-- for (const auto& [name, contents] : associated_files_) { -- if ((zipOpenNewFileInZip(zf, name.c_str(), -- /*zipfi=*/nullptr, -- /*extrafield_local=*/nullptr, -- /*size_extrafield_local=*/0, -- /*extrafield_global=*/nullptr, -- /*size_extrafield_global=*/0, -- /*comment=*/nullptr, -- /*method=*/0, -- /*level=*/Z_DEFAULT_COMPRESSION) != ZIP_OK) || -- (zipWriteInFileInZip(zf, contents.data(), contents.length()) != -- ZIP_OK) || -- (zipCloseFileInZip(zf) != ZIP_OK)) { -- return CreateStatusWithPayload( -- StatusCode::kUnknown, "Unable to write file to zip archive", -- TfLiteSupportStatus::kMetadataAssociatedFileZipError); -- } -- } -- // Close zip. -- if (zipClose(zf, /*global_comment=*/nullptr) != ZIP_OK) { -- return CreateStatusWithPayload( -- StatusCode::kUnknown, "Unable to close zip archive", -- TfLiteSupportStatus::kMetadataAssociatedFileZipError); -- } -- // Return as a string. -- return std::string(mem_file.GetFileContent()); --} -- --tflite::support::StatusOr<std::string> ModelMetadataPopulator::Populate() { -- // Build model. -- flatbuffers::FlatBufferBuilder model_fbb; -- model_fbb.Finish(tflite::Model::Pack(model_fbb, &model_t_), -- tflite::ModelIdentifier()); -- return AppendAssociatedFiles( -- reinterpret_cast<char*>(model_fbb.GetBufferPointer()), -- model_fbb.GetSize()); --} -- --} // namespace metadata --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc -deleted file mode 100644 -index f2b07e2054dfb..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc -+++ /dev/null -@@ -1,134 +0,0 @@ --/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. 
--You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#include "tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h" -- --#include <algorithm> --#include <cstdio> -- --#include "absl/strings/string_view.h" // from @com_google_absl --#include "contrib/minizip/ioapi.h" -- --namespace tflite { --namespace metadata { -- --ZipMemFile::ZipMemFile(const char* buffer, size_t size) -- : data_(buffer, size), offset_(0) { -- zlib_filefunc_def_.zopen_file = OpenFile; -- zlib_filefunc_def_.zread_file = ReadFile; -- zlib_filefunc_def_.zwrite_file = WriteFile; -- zlib_filefunc_def_.ztell_file = TellFile; -- zlib_filefunc_def_.zseek_file = SeekFile; -- zlib_filefunc_def_.zclose_file = CloseFile; -- zlib_filefunc_def_.zerror_file = ErrorFile; -- zlib_filefunc_def_.opaque = this; --} -- --zlib_filefunc_def& ZipMemFile::GetFileFuncDef() { -- return zlib_filefunc_def_; --} -- --absl::string_view ZipMemFile::GetFileContent() const { -- return data_; --} -- --/* static */ --voidpf ZipMemFile::OpenFile(voidpf opaque, const char* filename, int mode) { -- // Result is never used, but needs to be non-null for `zipOpen2` not to fail. 
-- return opaque; --} -- --/* static */ --size_t ZipMemFile::ReadFile(voidpf opaque, -- voidpf stream, -- void* buf, -- size_t size) { -- auto* mem_file = static_cast<ZipMemFile*>(opaque); -- if (mem_file->offset_ < 0 || mem_file->Size() < mem_file->offset_) { -- return 0; -- } -- if (mem_file->offset_ + size > mem_file->Size()) { -- size = mem_file->Size() - mem_file->offset_; -- } -- memcpy(buf, -- static_cast<const char*>(mem_file->data_.c_str()) + mem_file->offset_, -- size); -- mem_file->offset_ += size; -- return size; --} -- --/* static */ --size_t ZipMemFile::WriteFile(voidpf opaque, -- voidpf stream, -- const void* buf, -- size_t size) { -- auto* mem_file = static_cast<ZipMemFile*>(opaque); -- if (mem_file->offset_ + size > mem_file->Size()) { -- mem_file->data_.resize(mem_file->offset_ + size); -- } -- mem_file->data_.replace(mem_file->offset_, size, -- static_cast<const char*>(buf), size); -- mem_file->offset_ += size; -- return size; --} -- --/* static */ --ptrdiff_t ZipMemFile::TellFile(voidpf opaque, voidpf stream) { -- return static_cast<ZipMemFile*>(opaque)->offset_; --} -- --/* static */ --ptrdiff_t ZipMemFile::SeekFile(voidpf opaque, -- voidpf stream, -- size_t offset, -- int origin) { -- auto* mem_file = static_cast<ZipMemFile*>(opaque); -- switch (origin) { -- case SEEK_SET: -- mem_file->offset_ = offset; -- return 0; -- case SEEK_CUR: -- if (mem_file->offset_ + offset < 0 || -- mem_file->offset_ + offset > mem_file->Size()) { -- return -1; -- } -- mem_file->offset_ += offset; -- return 0; -- case SEEK_END: -- if (mem_file->Size() - offset < 0 || -- mem_file->Size() - offset > mem_file->Size()) { -- return -1; -- } -- mem_file->offset_ = offset + mem_file->Size(); -- return 0; -- default: -- return -1; -- } --} -- --/* static */ --int ZipMemFile::CloseFile(voidpf opaque, voidpf stream) { -- // Nothing to do. -- return 0; --} -- --/* static */ --int ZipMemFile::ErrorFile(voidpf opaque, voidpf stream) { -- // Unused. 
-- return 0; --} -- --} // namespace metadata --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h -deleted file mode 100644 -index d6748fcbe9ee1..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h -+++ /dev/null -@@ -1,75 +0,0 @@ --/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_MEM_FILE_H_ --#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_MEM_FILE_H_ -- --#include <cstdlib> -- --#include "absl/strings/string_view.h" // from @com_google_absl --#include "contrib/minizip/ioapi.h" -- --namespace tflite { --namespace metadata { -- --// In-memory zip file implementation. --// --// Adapted from [1], with a few key differences: --// * backed by an `std::string` instead of malloc-ed C buffers, --// * supports opening the file for writing through `zipOpen2`. --// --// [1]: --// https://github.com/google/libkml/blob/master/third_party/zlib-1.2.3/contrib/minizip/iomem_simple.c --class ZipMemFile { -- public: -- // Constructs an in-memory zip file from a buffer. 
-- ZipMemFile(const char* buffer, size_t size); -- // Provides access to the `zlib_filefunc_def` implementation for the in-memory -- // zip file. -- zlib_filefunc_def& GetFileFuncDef(); -- // Provides access to the file contents. -- absl::string_view GetFileContent() const; -- -- private: -- // The string backing the in-memory file. -- std::string data_; -- // The current offset in the file. -- size_t offset_; -- // The `zlib_filefunc_def` implementation for this in-memory zip file. -- zlib_filefunc_def zlib_filefunc_def_; -- -- // Convenience function to access the current data size. -- size_t Size() const { return data_.size(); } -- -- // The file function implementations used in the `zlib_filefunc_def`. -- static voidpf OpenFile(voidpf opaque, const char* filename, int mode); -- static size_t ReadFile(voidpf opaque, voidpf stream, void* buf, size_t size); -- static size_t WriteFile(voidpf opaque, -- voidpf stream, -- const void* buf, -- size_t size); -- static ptrdiff_t TellFile(voidpf opaque, voidpf stream); -- static ptrdiff_t SeekFile(voidpf opaque, -- voidpf stream, -- size_t offset, -- int origin); -- static int CloseFile(voidpf opaque, voidpf stream); -- static int ErrorFile(voidpf opaque, voidpf stream); --}; -- --} // namespace metadata --} // namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_MEM_FILE_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc -deleted file mode 100644 -index b5969e4e82bf8..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc -+++ /dev/null -@@ -1,84 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
-- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#include "pybind11/pybind11.h" --#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf --#include "tensorflow_lite_support/cc/task/audio/audio_classifier.h" --#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h" --#include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h" --#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" -- --namespace tflite { --namespace task { --namespace audio { -- --namespace { --namespace py = ::pybind11; --using PythonBaseOptions = ::tflite::python::task::core::BaseOptions; --using CppBaseOptions = ::tflite::task::core::BaseOptions; --} // namespace -- --PYBIND11_MODULE(_pywrap_audio_classifier, m) { -- // python wrapper for C++ AudioClassifier class which shouldn't be directly -- // used by the users. 
-- pybind11_protobuf::ImportNativeProtoCasters(); -- -- py::class_<AudioClassifier>(m, "AudioClassifier") -- .def_static( -- "create_from_options", -- [](const PythonBaseOptions& base_options, -- const processor::ClassificationOptions& classification_options) { -- AudioClassifierOptions options; -- auto cpp_base_options = -- core::convert_to_cpp_base_options(base_options); -- options.set_allocated_base_options(cpp_base_options.release()); -- -- if (classification_options.has_display_names_locale()) { -- options.set_display_names_locale( -- classification_options.display_names_locale()); -- } -- if (classification_options.has_max_results()) { -- options.set_max_results(classification_options.max_results()); -- } -- if (classification_options.has_score_threshold()) { -- options.set_score_threshold( -- classification_options.score_threshold()); -- } -- options.mutable_class_name_allowlist()->CopyFrom( -- classification_options.class_name_allowlist()); -- options.mutable_class_name_denylist()->CopyFrom( -- classification_options.class_name_denylist()); -- -- auto classifier = AudioClassifier::CreateFromOptions(options); -- return core::get_value(classifier); -- }) -- .def("classify", -- [](AudioClassifier& self, -- const AudioBuffer& audio_buffer) -> ClassificationResult { -- auto classification_result = self.Classify(audio_buffer); -- return core::get_value(classification_result); -- }) -- .def("get_required_audio_format", -- [](AudioClassifier& self) -> AudioBuffer::AudioFormat { -- auto audio_format = self.GetRequiredAudioFormat(); -- return core::get_value(audio_format); -- }) -- .def("get_required_input_buffer_size", -- &AudioClassifier::GetRequiredInputBufferSize); --} -- --} // namespace audio --} // namespace task --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc 
b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc -deleted file mode 100644 -index 8b1d67d9f8e05..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc -+++ /dev/null -@@ -1,78 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#include "pybind11/pybind11.h" --#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf --#include "tensorflow_lite_support/cc/task/audio/audio_embedder.h" --#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h" --#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h" --#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" -- --namespace tflite { --namespace task { --namespace audio { -- --namespace { --namespace py = ::pybind11; --using PythonBaseOptions = ::tflite::python::task::core::BaseOptions; --using CppBaseOptions = ::tflite::task::core::BaseOptions; --} // namespace -- --PYBIND11_MODULE(_pywrap_audio_embedder, m) { -- // python wrapper for C++ AudioEmbedder class which shouldn't be directly used -- // by the users. 
-- pybind11_protobuf::ImportNativeProtoCasters(); -- -- py::class_<AudioEmbedder>(m, "AudioEmbedder") -- .def_static( -- "create_from_options", -- [](const PythonBaseOptions& base_options, -- const processor::EmbeddingOptions& embedding_options) { -- AudioEmbedderOptions options; -- auto cpp_base_options = -- core::convert_to_cpp_base_options(base_options); -- -- options.set_allocated_base_options(cpp_base_options.release()); -- options.add_embedding_options()->CopyFrom(embedding_options); -- auto embedder = AudioEmbedder::CreateFromOptions(options); -- return core::get_value(embedder); -- }) -- .def_static("cosine_similarity", -- [](const processor::FeatureVector& u, -- const processor::FeatureVector& v) -> double { -- auto similarity = AudioEmbedder::CosineSimilarity(u, v); -- return core::get_value(similarity); -- }) -- .def("embed", -- [](AudioEmbedder& self, -- const AudioBuffer& audio_buffer) -> processor::EmbeddingResult { -- auto embedding_result = self.Embed(audio_buffer); -- return core::get_value(embedding_result); -- }) -- .def("get_embedding_dimension", &AudioEmbedder::GetEmbeddingDimension) -- .def("get_number_of_output_layers", -- &AudioEmbedder::GetNumberOfOutputLayers) -- .def("get_required_audio_format", -- [](AudioEmbedder& self) -> AudioBuffer::AudioFormat { -- auto audio_format = self.GetRequiredAudioFormat(); -- return core::get_value(audio_format); -- }) -- .def("get_required_input_buffer_size", -- &AudioEmbedder::GetRequiredInputBufferSize); --} -- --} // namespace audio --} // namespace task --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/_pywrap_text_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/_pywrap_text_embedder.cc -deleted file mode 100644 -index e148bdb773655..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/_pywrap_text_embedder.cc -+++ /dev/null -@@ -1,68 +0,0 @@ 
--/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#include "pybind11/pybind11.h" --#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf --#include "tensorflow_lite_support/cc/task/text/text_embedder.h" --#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" -- --namespace tflite { --namespace task { --namespace text { -- --namespace { --using PythonBaseOptions = ::tflite::python::task::core::BaseOptions; --using CppBaseOptions = ::tflite::task::core::BaseOptions; --} // namespace -- --PYBIND11_MODULE(_pywrap_text_embedder, m) { -- // python wrapper for C++ TextEmbeder class which shouldn't be directly used -- // by the users. 
-- pybind11_protobuf::ImportNativeProtoCasters(); -- -- pybind11::class_<TextEmbedder>(m, "TextEmbedder") -- .def_static( -- "create_from_options", -- [](const PythonBaseOptions& base_options, -- const processor::EmbeddingOptions& embedding_options) { -- TextEmbedderOptions options; -- auto cpp_base_options = -- core::convert_to_cpp_base_options(base_options); -- -- options.set_allocated_base_options(cpp_base_options.release()); -- options.add_embedding_options()->CopyFrom(embedding_options); -- auto embedder = TextEmbedder::CreateFromOptions(options); -- return core::get_value(embedder); -- }) -- .def("embed", -- [](TextEmbedder& self, -- const std::string& text) -> processor::EmbeddingResult { -- auto embedding_result = self.Embed(text); -- return core::get_value(embedding_result); -- }) -- .def("get_embedding_dimension", &TextEmbedder::GetEmbeddingDimension) -- .def("get_number_of_output_layers", -- &TextEmbedder::GetNumberOfOutputLayers) -- .def_static("cosine_similarity", -- [](const processor::FeatureVector& u, -- const processor::FeatureVector& v) -> double { -- auto similarity = TextEmbedder::CosineSimilarity(u, v); -- return core::get_value(similarity); -- }); --} -- --} // namespace text --} // namespace task --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc -deleted file mode 100644 -index 3b6bf2fc44dc5..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc -+++ /dev/null -@@ -1,68 +0,0 @@ --/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. 
--You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ --#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" -- --#include "pybind11/pybind11.h" --#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil -- --namespace tflite { --namespace task { --namespace vision { -- --namespace { --namespace py = ::pybind11; -- --} // namespace -- --PYBIND11_MODULE(image_utils, m) { -- // python wrapper for ImageData class which shouldn't be directly used by -- // the users. -- pybind11::google::ImportStatusModule(); -- -- py::class_<ImageData>(m, "ImageData", py::buffer_protocol()) -- .def(py::init([](py::buffer buffer) { -- py::buffer_info info = buffer.request(); -- -- if (info.ndim != 2 && info.ndim != 3) { -- throw py::value_error("Incompatible buffer dimension!"); -- } -- -- int height = info.shape[0]; -- int width = info.shape[1]; -- int channels = info.ndim == 3 ? 
info.shape[2] : 1; -- -- return ImageData{static_cast<uint8*>(info.ptr), width, height, -- channels}; -- })) -- .def_readonly("width", &ImageData::width) -- .def_readonly("height", &ImageData::height) -- .def_readonly("channels", &ImageData::channels) -- .def_buffer([](ImageData& data) -> py::buffer_info { -- return py::buffer_info( -- data.pixel_data, sizeof(uint8), -- py::format_descriptor<uint8>::format(), 3, -- {data.height, data.width, data.channels}, -- {sizeof(uint8) * size_t(data.width) * size_t(data.channels), -- sizeof(uint8) * size_t(data.channels), sizeof(uint8)}); -- }); -- -- m.def("DecodeImageFromFile", &DecodeImageFromFile); -- m.def("EncodeImageToPngFile", &EncodeImageToPngFile); -- m.def("ImageDataFree", &ImageDataFree); --} -- --} // namespace vision --} // namespace task --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc -deleted file mode 100644 -index f3f478d6f4f74..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc -+++ /dev/null -@@ -1,108 +0,0 @@ --/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#include "pybind11/pybind11.h" --#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf --#include "tensorflow_lite_support/cc/task/processor/proto/bounding_box.pb.h" --#include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h" --#include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h" --#include "tensorflow_lite_support/cc/task/vision/image_classifier.h" --#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" --#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" -- --namespace tflite { --namespace task { --namespace vision { -- --namespace { --namespace py = ::pybind11; --using PythonBaseOptions = ::tflite::python::task::core::BaseOptions; --using CppBaseOptions = ::tflite::task::core::BaseOptions; --} // namespace -- --PYBIND11_MODULE(_pywrap_image_classifier, m) { -- // python wrapper for C++ ImageClassifier class which shouldn't be directly -- // used by the users. 
-- pybind11_protobuf::ImportNativeProtoCasters(); -- -- py::class_<ImageClassifier>(m, "ImageClassifier") -- .def_static( -- "create_from_options", -- [](const PythonBaseOptions& base_options, -- const processor::ClassificationOptions& classification_options) { -- ImageClassifierOptions options; -- auto cpp_base_options = -- core::convert_to_cpp_base_options(base_options); -- options.set_allocated_base_options(cpp_base_options.release()); -- -- if (classification_options.has_display_names_locale()) { -- options.set_display_names_locale( -- classification_options.display_names_locale()); -- } -- if (classification_options.has_max_results()) { -- options.set_max_results(classification_options.max_results()); -- } -- if (classification_options.has_score_threshold()) { -- options.set_score_threshold( -- classification_options.score_threshold()); -- } -- options.mutable_class_name_whitelist()->CopyFrom( -- classification_options.class_name_allowlist()); -- options.mutable_class_name_blacklist()->CopyFrom( -- classification_options.class_name_denylist()); -- -- auto classifier = ImageClassifier::CreateFromOptions(options); -- return core::get_value(classifier); -- }) -- .def("classify", -- [](ImageClassifier& self, -- const ImageData& image_data) -> processor::ClassificationResult { -- auto frame_buffer = CreateFrameBufferFromImageData(image_data); -- auto vision_classification_result = -- self.Classify(*core::get_value(frame_buffer)); -- // Convert from vision::ClassificationResult to -- // processor::ClassificationResult as required by the Python layer. 
-- processor::ClassificationResult classification_result; -- classification_result.ParseFromString( -- core::get_value(vision_classification_result) -- .SerializeAsString()); -- return classification_result; -- }) -- .def("classify", -- [](ImageClassifier& self, const ImageData& image_data, -- const processor::BoundingBox& bounding_box) -- -> processor::ClassificationResult { -- // Convert from processor::BoundingBox to vision::BoundingBox as -- // the latter is used in the C++ layer. -- BoundingBox vision_bounding_box; -- vision_bounding_box.ParseFromString( -- bounding_box.SerializeAsString()); -- -- auto frame_buffer = CreateFrameBufferFromImageData(image_data); -- auto vision_classification_result = self.Classify( -- *core::get_value(frame_buffer), vision_bounding_box); -- // Convert from vision::ClassificationResult to -- // processor::ClassificationResult as required by the Python layer. -- processor::ClassificationResult classification_result; -- classification_result.ParseFromString( -- core::get_value(vision_classification_result) -- .SerializeAsString()); -- return classification_result; -- }); --} -- --} // namespace vision --} // namespace task --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_embedder.cc -deleted file mode 100644 -index 91c8db80ffaba..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_embedder.cc -+++ /dev/null -@@ -1,145 +0,0 @@ --/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. 
--You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#include <stdexcept> -- --#include "pybind11/pybind11.h" --#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf --#include "tensorflow_lite_support/cc/task/processor/proto/bounding_box.pb.h" --#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h" --#include "tensorflow_lite_support/cc/task/vision/image_embedder.h" --#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" --#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" -- --namespace tflite { --namespace task { --namespace vision { -- --namespace { --namespace py = ::pybind11; --using PythonBaseOptions = ::tflite::python::task::core::BaseOptions; --using CppBaseOptions = ::tflite::task::core::BaseOptions; --} // namespace -- --PYBIND11_MODULE(_pywrap_image_embedder, m) { -- // python wrapper for C++ ImageEmbeder class which shouldn't be directly used -- // by the users. 
-- pybind11_protobuf::ImportNativeProtoCasters(); -- -- py::class_<ImageEmbedder>(m, "ImageEmbedder") -- .def_static( -- "create_from_options", -- [](const PythonBaseOptions& base_options, -- const processor::EmbeddingOptions& embedding_options) { -- ImageEmbedderOptions options; -- if (base_options.has_file_content()) { -- options.mutable_model_file_with_metadata()->set_file_content( -- base_options.file_content()); -- } -- if (base_options.has_file_name()) { -- options.mutable_model_file_with_metadata()->set_file_name( -- base_options.file_name()); -- } -- -- options.set_num_threads(base_options.num_threads()); -- if (base_options.use_coral()) { -- options.mutable_compute_settings() -- ->mutable_tflite_settings() -- ->set_delegate(tflite::proto::Delegate::EDGETPU_CORAL); -- } -- -- if (embedding_options.has_l2_normalize()) { -- options.set_l2_normalize(embedding_options.l2_normalize()); -- } -- if (embedding_options.has_quantize()) { -- options.set_quantize(embedding_options.quantize()); -- } -- auto embedder = ImageEmbedder::CreateFromOptions(options); -- return get_value(embedder); -- }) -- .def("embed", -- [](ImageEmbedder& self, -- const ImageData& image_data) -> processor::EmbeddingResult { -- auto frame_buffer = CreateFrameBufferFromImageData(image_data); -- auto vision_embedding_result = -- self.Embed(*core::get_value(frame_buffer)); -- // Convert from vision::EmbeddingResult to -- // processor::EmbeddingResult -- processor::EmbeddingResult embedding_result; -- embedding_result.ParseFromString( -- core::get_value(vision_embedding_result).SerializeAsString()); -- return embedding_result; -- }) -- .def("embed", -- [](ImageEmbedder& self, const ImageData& image_data, -- const processor::BoundingBox& bounding_box) -- -> processor::EmbeddingResult { -- // Convert from processor::BoundingBox to vision::BoundingBox as -- // the later is used in the C++ layer. 
-- BoundingBox vision_bounding_box; -- vision_bounding_box.ParseFromString( -- bounding_box.SerializeAsString()); -- -- auto frame_buffer = CreateFrameBufferFromImageData(image_data); -- auto vision_embedding_result = self.Embed( -- *core::get_value(frame_buffer), vision_bounding_box); -- // Convert from vision::EmbeddingResult to -- // processor::EmbeddingResult as required by the Python layer. -- processor::EmbeddingResult embedding_result; -- embedding_result.ParseFromString( -- core::get_value(vision_embedding_result).SerializeAsString()); -- return embedding_result; -- }) -- .def("get_embedding_by_index", -- [](ImageEmbedder& self, -- const processor::EmbeddingResult& embedding_result, -- const int index) -> processor::Embedding { -- // Convert from processor::EmbeddingResult to -- // vision::EmbeddingResult as the latter is used in the C++ API. -- EmbeddingResult vision_embedding_result; -- vision_embedding_result.ParseFromString( -- embedding_result.SerializeAsString()); -- -- Embedding vision_embedding{ -- self.GetEmbeddingByIndex(vision_embedding_result, index)}; -- // Convert from vision::Embedding to processor::Embedding -- // as required by the Python layer. -- processor::Embedding embedding; -- embedding.ParseFromString(vision_embedding.SerializeAsString()); -- return embedding; -- }) -- .def("get_number_of_output_layers", -- &ImageEmbedder::GetNumberOfOutputLayers) -- .def("get_embedding_dimension", &ImageEmbedder::GetEmbeddingDimension) -- .def_static( -- "cosine_similarity", -- [](const processor::FeatureVector& u, -- const processor::FeatureVector& v) -> double { -- // Convert from processor::FeatureVector to -- // vision::FeatureVector as the latter is used in the C++ -- // layer. 
-- FeatureVector vision_feature_vector_u; -- vision_feature_vector_u.ParseFromString(u.SerializeAsString()); -- FeatureVector vision_feature_vector_v; -- vision_feature_vector_v.ParseFromString(v.SerializeAsString()); -- auto similarity = ImageEmbedder::CosineSimilarity( -- vision_feature_vector_u, vision_feature_vector_v); -- return core::get_value(similarity); -- }); --} -- --} // namespace vision --} // namespace task --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc -deleted file mode 100644 -index 19d1f31b2e78c..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc -+++ /dev/null -@@ -1,73 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#include "pybind11/pybind11.h" --#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf --#include "tensorflow_lite_support/cc/task/processor/proto/segmentation_options.pb.h" --#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h" --#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" --#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" -- --namespace tflite { --namespace task { --namespace vision { -- --namespace { --namespace py = ::pybind11; --using PythonBaseOptions = ::tflite::python::task::core::BaseOptions; --using CppBaseOptions = ::tflite::task::core::BaseOptions; --} // namespace -- --PYBIND11_MODULE(_pywrap_image_segmenter, m) { -- // python wrapper for C++ ImageSegmenter class which shouldn't be directly -- // used by the users. -- pybind11_protobuf::ImportNativeProtoCasters(); -- -- py::class_<ImageSegmenter>(m, "ImageSegmenter") -- .def_static( -- "create_from_options", -- [](const PythonBaseOptions& base_options, -- const processor::SegmentationOptions& segmentation_options) { -- ImageSegmenterOptions options; -- auto cpp_base_options = -- core::convert_to_cpp_base_options(base_options); -- options.set_allocated_base_options(cpp_base_options.release()); -- -- if (segmentation_options.has_display_names_locale()) { -- options.set_display_names_locale( -- segmentation_options.display_names_locale()); -- } -- if (segmentation_options.has_output_type()) { -- options.set_output_type( -- static_cast<ImageSegmenterOptions::OutputType>( -- segmentation_options.output_type())); -- } -- -- auto segmenter = ImageSegmenter::CreateFromOptions(options); -- return core::get_value(segmenter); -- }) -- .def("segment", -- [](ImageSegmenter& self, -- const ImageData& image_data) -> SegmentationResult { -- auto frame_buffer = CreateFrameBufferFromImageData(image_data); -- auto 
vision_segmentation_result = -- self.Segment(*core::get_value(frame_buffer)); -- return core::get_value(vision_segmentation_result); -- }); --} -- --} // namespace vision --} // namespace task --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc -deleted file mode 100644 -index 36fa2372e60af..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc -+++ /dev/null -@@ -1,84 +0,0 @@ --/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#include "pybind11/pybind11.h" --#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf --#include "tensorflow_lite_support/cc/task/processor/proto/detection_options.pb.h" --#include "tensorflow_lite_support/cc/task/processor/proto/detections.pb.h" --#include "tensorflow_lite_support/cc/task/vision/object_detector.h" --#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" --#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" -- --namespace tflite { --namespace task { --namespace vision { -- --namespace { --namespace py = ::pybind11; --using PythonBaseOptions = ::tflite::python::task::core::BaseOptions; --using CppBaseOptions = ::tflite::task::core::BaseOptions; --} // namespace -- --PYBIND11_MODULE(_pywrap_object_detector, m) { -- // python wrapper for C++ ObjectDetector class which shouldn't be directly -- // used by the users. 
-- pybind11_protobuf::ImportNativeProtoCasters(); -- -- py::class_<ObjectDetector>(m, "ObjectDetector") -- .def_static( -- "create_from_options", -- [](const PythonBaseOptions& base_options, -- const processor::DetectionOptions& detection_options) { -- ObjectDetectorOptions options; -- auto cpp_base_options = -- core::convert_to_cpp_base_options(base_options); -- options.set_allocated_base_options(cpp_base_options.release()); -- -- if (detection_options.has_display_names_locale()) { -- options.set_display_names_locale( -- detection_options.display_names_locale()); -- } -- if (detection_options.has_max_results()) { -- options.set_max_results(detection_options.max_results()); -- } -- if (detection_options.has_score_threshold()) { -- options.set_score_threshold(detection_options.score_threshold()); -- } -- options.mutable_class_name_whitelist()->CopyFrom( -- detection_options.class_name_allowlist()); -- options.mutable_class_name_blacklist()->CopyFrom( -- detection_options.class_name_denylist()); -- -- auto detector = ObjectDetector::CreateFromOptions(options); -- return core::get_value(detector); -- }) -- .def("detect", -- [](ObjectDetector& self, -- const ImageData& image_data) -> processor::DetectionResult { -- auto frame_buffer = CreateFrameBufferFromImageData(image_data); -- auto vision_detection_result = -- self.Detect(*core::get_value(frame_buffer)); -- // Convert from vision::DetectionResult to -- // processor::DetectionResult as required by the Python layer. 
-- processor::DetectionResult detection_result; -- detection_result.ParseFromString( -- core::get_value(vision_detection_result).SerializeAsString()); -- return detection_result; -- }); --} -- --} // namespace vision --} // namespace task --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h -deleted file mode 100644 -index 67e0e303d4231..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h -+++ /dev/null -@@ -1,256 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ --#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_INDEX_TABLE_SUM_H_ --#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_INDEX_TABLE_SUM_H_ -- --#include <array> --#include <cstddef> --#include <cstdint> --#include <type_traits> --#include <vector> -- --#include "Eigen/Core" // from @eigen --#include "tensorflow_lite_support/scann_ondevice/cc/core/simd_utils.h" -- --namespace tflite { --namespace scann_ondevice { --namespace core { -- --template <typename LutType> --void RearrangeLUT(const LutType* input_data, -- int batch_elems, -- int batch_size, -- LutType* const output_data) { -- std::vector<int64_t> simd_sizes; -- if (std::is_same<LutType, float>::value) { --#ifdef __AVX__ -- simd_sizes = {8, 4}; --#elif defined __SSE__ -- simd_sizes = {4}; --#elif defined __ARM_NEON__ -- simd_sizes = {4}; --#endif -- } else { --#ifdef __AVX2__ -- simd_sizes = {16, 8}; --#elif defined __SSE4_1__ -- simd_sizes = {8}; --#elif defined __ARM_NEON__ -- simd_sizes = {8}; --#endif -- } -- -- int64_t offset = 0; -- for (int64_t simd_size : simd_sizes) { -- const int64_t num_simds = batch_size / simd_size; -- const int64_t simd_batch_elems = simd_size * batch_elems; -- for (; offset < num_simds * simd_batch_elems; offset += simd_batch_elems) { -- using RowMajorMatrix = Eigen::Matrix<LutType, Eigen::Dynamic, -- Eigen::Dynamic, Eigen::RowMajor>; -- Eigen::Map<const RowMajorMatrix> input_map(input_data + offset, simd_size, -- batch_elems); -- Eigen::Map<RowMajorMatrix> output_map(output_data + offset, batch_elems, -- simd_size); -- output_map = input_map.transpose(); -- } -- } -- std::copy(input_data + offset, input_data + batch_elems * batch_size, -- output_data + offset); --} --const int kDefaultChunksPerBlock = 32; --const int k16CentersUint8LutChunksPerBlock = 256; --const int kUnrollSteps = 6; -- --template <typename T> --struct MaxQuantizationValue { -- static_assert(std::is_same<T, 
float>::value, "Invalid lookup table type."); -- static constexpr size_t value = 0; --}; -- --template <> --struct MaxQuantizationValue<uint8_t> { -- static constexpr size_t value = 255; --}; -- --template <> --struct MaxQuantizationValue<uint16_t> { -- static constexpr size_t value = (1 << 16) / kDefaultChunksPerBlock - 1; --}; -- --template <typename SimdType, typename LutType, size_t NumCenters = 0> --size_t IndexTableSumSimdBatch(const uint8_t* indices, -- size_t num_chunks, -- size_t num_outputs, -- const LutType* lookup_table, -- size_t batch_size, -- size_t num_centers, -- float min, -- float max, -- size_t batch_index, -- float* const output) { -- if (num_centers == 256) { -- return IndexTableSumSimdBatch<SimdType, LutType, 256>( -- indices, num_chunks, num_outputs, lookup_table, batch_size, 0, min, max, -- batch_index, output); -- } -- const size_t lut_chunk_stride = NumCenters ? NumCenters * SimdType::size() -- : num_centers * SimdType::size(); -- const size_t lut_item_stride = -- NumCenters ? NumCenters * num_chunks : num_chunks * num_centers; -- constexpr bool must_dequantize = !std::is_same<LutType, float>::value; -- constexpr size_t max_qval = MaxQuantizationValue<LutType>::value; -- const float dq_scale = must_dequantize ? (max - min) / max_qval : 0.0f; -- const float dq_offset_1 = must_dequantize ? min + dq_scale / 2 : 0.0f; -- -- const size_t chunks_per_block = -- std::is_same<LutType, uint8_t>::value && -- (NumCenters ? NumCenters : num_centers) == 16 -- ? 
k16CentersUint8LutChunksPerBlock -- : kDefaultChunksPerBlock; -- -- for (; batch_index + SimdType::size() <= batch_size; -- batch_index += SimdType::size()) { -- const LutType* batch_lut = lookup_table + batch_index * lut_item_stride; -- float* const batch_output = output + batch_index; -- for (size_t block_start = 0; block_start < num_chunks; -- block_start += chunks_per_block) { -- const size_t block_end = -- std::min(block_start + chunks_per_block, num_chunks); -- const float dq_offset_n = (block_end - block_start) * dq_offset_1; -- size_t output_index; -- for (output_index = 0; output_index + kUnrollSteps <= num_outputs; -- output_index += kUnrollSteps) { -- const uint8_t* indices_base = indices + output_index * num_chunks; -- size_t chunk_index = block_start; -- const LutType* chunk_lut = batch_lut + chunk_index * lut_chunk_stride; -- std::array<SimdType, kUnrollSteps> accums; -- for (size_t i = 0; i < kUnrollSteps; ++i) { -- const size_t center_index = -- indices_base[i * num_chunks + chunk_index]; -- accums[i].load(chunk_lut + center_index * SimdType::size()); -- } -- ++chunk_index; -- chunk_lut += lut_chunk_stride; -- for (; chunk_index < block_end; ++chunk_index) { -- for (size_t i = 0; i < kUnrollSteps; ++i) { -- SimdType simd; -- const size_t center_index = -- indices_base[i * num_chunks + chunk_index]; -- simd.load(chunk_lut + center_index * SimdType::size()); -- accums[i] += simd; -- } -- chunk_lut += lut_chunk_stride; -- } -- for (size_t i = 0; i < kUnrollSteps; ++i) { -- accums[i].dequantize_accum_storeu( -- batch_output + (output_index + i) * batch_size, dq_scale, -- dq_offset_n); -- } -- } -- for (; output_index < num_outputs; ++output_index) { -- const uint8_t* vector_indices = indices + output_index * num_chunks; -- -- SimdType accum; -- accum.setzero(); -- size_t chunk_index = block_start; -- const LutType* chunk_lut = batch_lut + chunk_index * lut_chunk_stride; -- for (; chunk_index < block_end; ++chunk_index) { -- SimdType simd; -- 
simd.load(chunk_lut + vector_indices[chunk_index] * SimdType::size()); -- accum += simd; -- chunk_lut += lut_chunk_stride; -- } -- -- accum.dequantize_accum_storeu(batch_output + output_index * batch_size, -- dq_scale, dq_offset_n); -- } -- } -- } -- -- return batch_index; --} -- --template <typename LutType> --void IndexTableSum(const uint8_t* indices, -- size_t num_chunks, -- size_t num_outputs, -- const LutType* lookup_table, -- size_t batch_size, -- size_t num_centers, -- float min, -- float max, -- float* const output) { -- static_assert(std::is_same<LutType, uint8_t>::value || -- std::is_same<LutType, uint16_t>::value, -- "Invalid lookup table type."); -- std::fill(output, output + batch_size * num_outputs, 0.0f); -- size_t i = 0; --#ifdef __AVX2__ -- i = IndexTableSumSimdBatch<SimdInt16x16, LutType>( -- indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, -- min, max, i, output); --#endif --#ifdef __SSE4_1__ -- i = IndexTableSumSimdBatch<SimdInt16x8, LutType>( -- indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, -- min, max, i, output); --#endif --#ifdef __ARM_NEON__ -- i = IndexTableSumSimdBatch<SimdInt16x8, LutType>( -- indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, -- min, max, i, output); --#endif -- i = IndexTableSumSimdBatch<SimdInt16x1, LutType>( -- indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, -- min, max, i, output); --} -- --template <> --inline void IndexTableSum<float>(const uint8_t* indices, -- size_t num_chunks, -- size_t num_outputs, -- const float* lookup_table, -- size_t batch_size, -- size_t num_centers, -- float min, -- float max, -- float* const output) { -- std::fill(output, output + batch_size * num_outputs, 0.0f); -- size_t i = 0; --#ifdef __AVX__ -- i = IndexTableSumSimdBatch<SimdFloat32x8, float>( -- indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, -- min, max, i, output); --#endif --#ifdef __SSE__ -- i = 
IndexTableSumSimdBatch<SimdFloat32x4, float>( -- indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, -- min, max, i, output); --#endif --#ifdef __ARM_NEON__ -- i = IndexTableSumSimdBatch<SimdFloat32x4, float>( -- indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, -- min, max, i, output); --#endif -- i = IndexTableSumSimdBatch<SimdFloat32x1, float>( -- indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, -- min, max, i, output); --} -- --} // namespace core --} // namespace scann_ondevice --} // namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_INDEX_TABLE_SUM_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h -deleted file mode 100644 -index f4e9eb9e34804..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h -+++ /dev/null -@@ -1,76 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ --#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PARTITIONER_H_ --#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PARTITIONER_H_ -- --#include <utility> -- --#include "Eigen/Core" // from @eigen --#include "absl/types/optional.h" // from @com_google_absl --#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h" -- --namespace tflite { --namespace scann_ondevice { --namespace core { --class PartitionerInterface { -- public: -- virtual ~PartitionerInterface() {} -- virtual bool Partition(const Eigen::Ref<const Eigen::MatrixXf>& queries, -- std::vector<std::vector<int>>* tokens) const = 0; -- -- virtual int NumPartitions() const = 0; -- virtual absl::optional<int> get_vector_dimension() const = 0; --}; --class Partitioner : public PartitionerInterface { -- public: -- static std::unique_ptr<Partitioner> Create(const PartitionerProto& proto); -- bool Partition(const Eigen::Ref<const Eigen::MatrixXf>& queries, -- std::vector<std::vector<int>>* tokens) const override; -- int NumPartitions() const override; -- -- inline absl::optional<int> get_vector_dimension() const override { -- return leaves_.cols(); -- } -- -- private: -- Partitioner(Eigen::MatrixXf leaves, -- Eigen::VectorXf leaf_norms, -- DistanceMeasure distance) -- : leaves_(std::move(leaves)), -- leaf_norms_(std::move(leaf_norms)), -- distance_(distance) {} -- -- Eigen::MatrixXf leaves_; -- Eigen::VectorXf leaf_norms_; -- DistanceMeasure distance_; --}; --class NoOpPartitioner : public PartitionerInterface { -- public: -- ~NoOpPartitioner() override {} -- -- bool Partition(const Eigen::Ref<const Eigen::MatrixXf>& queries, -- std::vector<std::vector<int>>* tokens) const override; -- -- int NumPartitions() const override; -- inline absl::optional<int> get_vector_dimension() const override { -- return absl::optional<int>(); -- } --}; -- --} // namespace core --} // namespace scann_ondevice --} 
// namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PARTITIONER_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/processor.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/processor.h -deleted file mode 100644 -index 97206f4ba1aa6..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/processor.h -+++ /dev/null -@@ -1,101 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ --#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PROCESSOR_H_ --#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PROCESSOR_H_ -- --#include <cstdint> --#include <utility> --#include <vector> -- --#include "Eigen/Core" // from @eigen --#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h" -- --namespace tflite { --namespace scann_ondevice { --namespace core { --struct QueryInfo { -- template <typename T> -- using Matrix = -- Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>; -- -- float fixed_point_min = NAN; -- float fixed_point_max = NAN; -- float fixed_point_offset = NAN; -- float fixed_point_scale = NAN; -- -- std::shared_ptr<Matrix<float>> query_lut; -- std::shared_ptr<Matrix<uint16_t>> query_lut_uint16; -- std::shared_ptr<Matrix<uint8_t>> query_lut_uint8; -- template <typename T> -- std::shared_ptr<Matrix<T>> QueryLUT(); -- -- std::shared_ptr<Matrix<float>> transposed_query_lut; -- std::shared_ptr<Matrix<uint16_t>> transposed_query_lut_uint16; -- std::shared_ptr<Matrix<uint8_t>> transposed_query_lut_uint8; -- template <typename T> -- std::shared_ptr<Matrix<T>> TransposedQueryLUT(); --}; --class PreProcessorInterface { -- public: -- virtual ~PreProcessorInterface() {} -- -- virtual bool Process(const Eigen::Ref<const Eigen::MatrixXf>& queries, -- QueryInfo* query_info) const = 0; -- virtual int num_database_dims() const = 0; -- virtual int num_query_dims() const = 0; --}; --class PostProcessorInterface { -- public: -- virtual ~PostProcessorInterface() {} -- -- virtual bool Process(std::vector<TopN>* top_n) const = 0; --}; --class AsymmetricHashQuerier : public PreProcessorInterface { -- public: -- static std::unique_ptr<AsymmetricHashQuerier> Create( -- const AsymmetricHashingProto& proto); -- bool Process(const Eigen::Ref<const 
Eigen::MatrixXf>& queries, -- QueryInfo* query_info) const override; -- -- inline int num_database_dims() const override { return codebooks_.size(); } -- -- inline int num_query_dims() const override { return dims_; } -- -- private: -- AsymmetricHashQuerier(std::vector<Eigen::MatrixXf> codebooks, -- std::vector<Eigen::VectorXf> codebook_norms, -- DistanceMeasure query_distance, -- AsymmetricHashingProto::LookupType lookup_type, -- int dims) -- : dims_(dims), -- lookup_type_(lookup_type), -- query_distance_(query_distance), -- codebooks_(std::move(codebooks)), -- codebook_norms_(std::move(codebook_norms)) {} -- void RearrangeLUT(QueryInfo* query_info) const; -- -- int dims_; -- AsymmetricHashingProto::LookupType lookup_type_; -- DistanceMeasure query_distance_; -- std::vector<Eigen::MatrixXf> codebooks_; -- std::vector<Eigen::VectorXf> codebook_norms_; --}; -- --} // namespace core --} // namespace scann_ondevice --} // namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PROCESSOR_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h -deleted file mode 100644 -index 419681b829b1d..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h -+++ /dev/null -@@ -1,256 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
--See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ --#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_SEARCHER_H_ --#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_SEARCHER_H_ -- --#include <algorithm> --#include <cstdint> --#include <limits> --#include <utility> --#include <vector> -- --#include <glog/logging.h> --#include "Eigen/Core" // from @eigen --#include "absl/types/span.h" // from @com_google_absl --#include "tensorflow_lite_support/cc/port/integral_types.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/simd_utils.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h" -- --namespace tflite { --namespace scann_ondevice { --namespace core { -- --using Matrix8u = -- Eigen::Matrix<uint8_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>; -- --namespace internal { --void ComputeAHDistance(const QueryInfo& query_info, -- Eigen::Ref<const Matrix8u> database, -- Eigen::Ref<Eigen::MatrixXf> output); -- --} --template <class T> --bool AsymmetricHashFindNeighbors(const QueryInfo& query_info, -- Eigen::Ref<const Matrix8u> database, -- size_t global_offset, -- absl::Span<T> topn) { -- const int batch_size = query_info.query_lut->cols(); -- if (topn.size() != batch_size) { -- return false; -- } -- int database_size = database.cols(); -- Eigen::MatrixXf output(batch_size, database_size); -- internal::ComputeAHDistance(query_info, database, output); -- -- for (int i = 0; i < database_size; i++) { -- for (int j = 0; j < topn.size(); ++j) { -- topn[j].emplace(output(j, i), i + global_offset); -- } -- } -- return true; --} --template <class T> --bool 
AsymmetricHashFindNeighbors(Eigen::Ref<const Eigen::MatrixXf> queries, -- const PreProcessorInterface& preprocessor, -- Eigen::Ref<const Matrix8u> database, -- size_t global_offset, -- absl::Span<T> topn) { -- if (queries.cols() != topn.size()) { -- return false; -- } -- QueryInfo query_info; -- return preprocessor.Process(queries, &query_info) && -- AsymmetricHashFindNeighbors(query_info, database, global_offset, topn); --} --template <class T> --bool FloatFindNeighbors(Eigen::Ref<const Eigen::MatrixXf> queries, -- Eigen::Ref<const Eigen::MatrixXf> database, -- const size_t global_offset, -- const DistanceMeasure distance_measure, -- absl::Span<T> topn) { -- int query_size = queries.cols(); -- int database_size = database.cols(); -- Eigen::MatrixXf pairwise_distances(query_size, database_size); -- -- if (distance_measure == SQUARED_L2_DISTANCE) { -- pairwise_distances.colwise() = queries.colwise().squaredNorm().transpose(); -- pairwise_distances.rowwise() += database.colwise().squaredNorm(); -- pairwise_distances -= 2 * queries.transpose() * database; -- } else if (distance_measure == DOT_PRODUCT) { -- pairwise_distances = -1 * queries.transpose() * database; -- } else { -- LOG(ERROR) << "Unsupported distance measure: " -- << DistanceMeasure_Name(distance_measure); -- return false; -- } -- -- for (int i = 0; i < database_size; ++i) { -- for (int j = 0; j < query_size; ++j) { -- topn[j].emplace(pairwise_distances(j, i), i + global_offset); -- } -- } -- return true; --} --template <class T> --class SearcherInterfaceT { -- public: -- virtual ~SearcherInterfaceT() {} -- -- virtual bool FindNeighbors(const Eigen::Ref<const Eigen::MatrixXf>& queries, -- std::vector<T>* topn) const = 0; --}; --template <class T> --class AsymmetricHashLeafSearcherT : public SearcherInterfaceT<T> { -- public: -- static std::unique_ptr<AsymmetricHashLeafSearcherT<T>> Create( -- std::shared_ptr<QueryInfo::Matrix<uint8_t>> database, -- int global_offset, -- 
std::shared_ptr<PreProcessorInterface> preprocessor); -- static std::unique_ptr<AsymmetricHashLeafSearcherT<T>> Create( -- std::shared_ptr<QueryInfo::Matrix<uint8_t>> database, -- int global_offset, -- std::shared_ptr<PreProcessorInterface> preprocessor, -- size_t mini_batch_size); -- bool FindNeighbors(const Eigen::Ref<const Eigen::MatrixXf>& queries, -- std::vector<T>* topn) const override; -- bool FindNeighbors(const QueryInfo& query_info, std::vector<T>* topn) const; -- -- private: -- AsymmetricHashLeafSearcherT( -- std::shared_ptr<QueryInfo::Matrix<uint8_t>> database, -- int global_offset, -- std::shared_ptr<PreProcessorInterface> preprocessor, -- size_t mini_batch_size) -- : database_(std::move(database)), -- global_offset_(global_offset), -- preprocessor_(std::move(preprocessor)), -- mini_batch_size_(mini_batch_size) {} -- std::shared_ptr<QueryInfo::Matrix<uint8_t>> database_ = nullptr; -- int global_offset_; -- std::shared_ptr<PreProcessorInterface> preprocessor_ = nullptr; -- const size_t mini_batch_size_; --}; --template <class T> --class LinearLeafSearcherT : public SearcherInterfaceT<T> { -- public: -- ~LinearLeafSearcherT() override {} -- static std::unique_ptr<LinearLeafSearcherT<T>> Create( -- std::shared_ptr<Eigen::MatrixXf> database, -- DistanceMeasure distance_measure = SQUARED_L2_DISTANCE, -- int global_offset = 0); -- -- bool FindNeighbors(const Eigen::Ref<const Eigen::MatrixXf>& queries, -- std::vector<T>* topn) const override; -- -- private: -- LinearLeafSearcherT(std::shared_ptr<Eigen::MatrixXf> database, -- DistanceMeasure distance_measure, -- int global_offset) -- : database_(std::move(database)), -- distance_measure_(distance_measure), -- global_offset_(global_offset) {} -- -- std::shared_ptr<Eigen::MatrixXf> database_ = nullptr; -- const DistanceMeasure distance_measure_; -- int global_offset_; --}; -- --template <class T> --std::unique_ptr<AsymmetricHashLeafSearcherT<T>> --AsymmetricHashLeafSearcherT<T>::Create( -- 
std::shared_ptr<Matrix8u> database, -- int global_offset, -- std::shared_ptr<PreProcessorInterface> preprocessor) { -- return AsymmetricHashLeafSearcherT<T>::Create( -- database, global_offset, preprocessor, -- std::numeric_limits<size_t>::max()); --} -- --template <class T> --std::unique_ptr<AsymmetricHashLeafSearcherT<T>> --AsymmetricHashLeafSearcherT<T>::Create( -- std::shared_ptr<Matrix8u> database, -- int global_offset, -- std::shared_ptr<PreProcessorInterface> preprocessor, -- size_t mini_batch_size) { -- if (mini_batch_size == 0 || global_offset < 0) { -- return nullptr; -- } -- return std::unique_ptr<AsymmetricHashLeafSearcherT<T>>( -- new AsymmetricHashLeafSearcherT<T>(std::move(database), global_offset, -- std::move(preprocessor), -- mini_batch_size)); --} -- --template <class T> --bool AsymmetricHashLeafSearcherT<T>::FindNeighbors( -- const Eigen::Ref<const Eigen::MatrixXf>& queries, -- std::vector<T>* topn) const { -- if (queries.cols() != topn->size()) { -- return false; -- } -- -- absl::Span<T> topn_span = absl::MakeSpan(*topn); -- for (size_t i = 0; i < queries.cols(); i += mini_batch_size_) { -- const size_t num_queries_in_batch = -- std::min(mini_batch_size_, queries.cols() - i); -- if (!AsymmetricHashFindNeighbors<T>( -- queries.middleCols(i, num_queries_in_batch), *preprocessor_, -- *database_, global_offset_, -- topn_span.subspan(i, num_queries_in_batch))) { -- return false; -- } -- } -- return true; --} -- --template <class T> --bool AsymmetricHashLeafSearcherT<T>::FindNeighbors(const QueryInfo& query_info, -- std::vector<T>* topn) const { -- return AsymmetricHashFindNeighbors<T>(query_info, *database_, global_offset_, -- absl::MakeSpan(*topn)); --} -- --template <class T> --std::unique_ptr<LinearLeafSearcherT<T>> LinearLeafSearcherT<T>::Create( -- std::shared_ptr<Eigen::MatrixXf> database, -- DistanceMeasure distance_measure, -- int global_offset) { -- if (global_offset < 0) { -- return nullptr; -- } -- return 
std::unique_ptr<LinearLeafSearcherT<T>>(new LinearLeafSearcherT<T>( -- std::move(database), distance_measure, global_offset)); --} -- --template <class T> --bool LinearLeafSearcherT<T>::FindNeighbors( -- const Eigen::Ref<const Eigen::MatrixXf>& queries, -- std::vector<T>* topn) const { -- return FloatFindNeighbors<T>(queries, *database_, global_offset_, -- distance_measure_, absl::MakeSpan(*topn)); --} -- --using SearcherInterface = SearcherInterfaceT<TopN>; --using AsymmetricHashLeafSearcher = AsymmetricHashLeafSearcherT<TopN>; --using LinearLeafSearcher = LinearLeafSearcherT<TopN>; --} // namespace core --} // namespace scann_ondevice --} // namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_SEARCHER_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc -deleted file mode 100644 -index f3931f3619b8d..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc -+++ /dev/null -@@ -1,532 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ --#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h" -- --#include <algorithm> --#include <cstdint> --#include <limits> --#include <memory> --#include <utility> -- --#include <glog/logging.h> --#include "Eigen/Core" // from @eigen --#include "absl/synchronization/mutex.h" // from @com_google_absl --#include "tensorflow_lite_support/cc/port/gmock.h" --#include "tensorflow_lite_support/cc/port/gtest.h" --#include "tensorflow_lite_support/cc/port/integral_types.h" --#include "tensorflow_lite_support/cc/port/proto2.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h" --#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h" --using TextFormat = ::tflite::support::proto::TextFormat; -- --using Eigen::MatrixXf; --using ::testing::ElementsAre; --using ::testing::Pair; --using ::testing::TestWithParam; --using ::testing::Values; --using Matrix8u = -- Eigen::Matrix<uint8_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>; --using tflite::scann_ondevice::core::TopN; -- --const char kExampleAsymmetricHashingProtoString[] = -- R"( -- subspace: { -- entry { -- dimension: 0.1; -- dimension: 0.2; -- } -- entry: { -- dimension: 0.2; -- dimension: 0.1; -- } -- entry: { -- dimension: 0.9; -- dimension: 0.8; -- } -- } -- subspace: { -- entry { -- dimension: -0.1; -- dimension: -0.2; -- dimension: -0.3; -- } -- entry: { -- dimension: -0.3; -- dimension: -0.2; -- dimension: -0.1; -- } -- entry: { -- dimension: -0.9; -- dimension: -0.8; -- dimension: -0.7; -- } -- })"; -- --const char kExamplePartitionerProtoString[] = -- R"( -- leaf: { -- dimension: 0.1; -- dimension: 0.2; -- } -- leaf: { -- dimension: 0.2; -- dimension: 0.1; -- } -- leaf: { -- dimension: 0.9; -- dimension: 0.7; -- } -- leaf: { -- dimension: 0.3; -- dimension: 0.3; -- })"; --namespace tflite { 
--namespace scann_ondevice { --namespace core { --namespace { --TEST(PartitionerTest, Partition) { -- PartitionerProto proto; -- TextFormat::ParseFromString(kExamplePartitionerProtoString, &proto); -- proto.set_query_distance(SQUARED_L2_DISTANCE); -- auto partitioner = Partitioner::Create(proto); -- MatrixXf query(2, 3); -- query << 0.3, 0.9, -1, 0.2, 0.9, -1; -- -- std::vector<std::vector<int>> tokens(3, std::vector<int>(2, -1)); -- ASSERT_TRUE(partitioner->Partition(query, &tokens)); -- for (int i = 0; i < 3; ++i) { -- std::sort(tokens[i].begin(), tokens[i].end()); -- } -- EXPECT_EQ((std::vector<int>{1, 3}), tokens[0]); -- EXPECT_EQ((std::vector<int>{2, 3}), tokens[1]); -- EXPECT_EQ((std::vector<int>{0, 1}), tokens[2]); --} -- --TEST(PartitionerTest, PartitionDotProductDistance) { -- PartitionerProto proto; -- TextFormat::ParseFromString(kExamplePartitionerProtoString, &proto); -- proto.set_query_distance(DOT_PRODUCT); -- auto partitioner = Partitioner::Create(proto); -- MatrixXf query(2, 3); -- query << 0.3, 0.9, -1, 0.2, 0.9, -1; -- -- std::vector<std::vector<int>> tokens(3, std::vector<int>(2, -1)); -- ASSERT_TRUE(partitioner->Partition(query, &tokens)); -- for (int i = 0; i < 3; ++i) { -- std::sort(tokens[i].begin(), tokens[i].end()); -- } -- EXPECT_EQ((std::vector<int>{2, 3}), tokens[0]); -- EXPECT_EQ((std::vector<int>{2, 3}), tokens[1]); -- EXPECT_EQ((std::vector<int>{0, 1}), tokens[2]); --} -- --TEST(ProcessorTest, AsymmetricHashQuerierNonSimd) { -- AsymmetricHashingProto proto; -- TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto); -- auto querier = AsymmetricHashQuerier::Create(proto); -- CHECK(querier); -- -- MatrixXf query(5, 2); -- query << 0, 1, 0, 1, 0, 1, 0, 1, 0, 1; -- QueryInfo query_info; -- ASSERT_TRUE(querier->Process(query, &query_info)); -- MatrixXf expected_lut(6, 2); -- expected_lut << 0.05, 1.45, 0.05, 1.45, 1.45, 0.05, 0.14, 4.34, 0.14, 4.34, -- 1.94, 9.74; -- 
ASSERT_TRUE(query_info.query_lut->isApprox(expected_lut, 1e-5)); --} -- --TEST(ProcessorTest, AsymmetricHashQuerierNonSimdDotProduct) { -- AsymmetricHashingProto proto; -- TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto); -- proto.set_query_distance(DOT_PRODUCT); -- auto querier = AsymmetricHashQuerier::Create(proto); -- ASSERT_NE(querier, nullptr); -- -- MatrixXf query(5, 2); -- query << 0, 1, 0, 1, 0, 1, 0, 1, 0, 1; -- QueryInfo query_info; -- ASSERT_TRUE(querier->Process(query, &query_info)); -- -- const auto& query_lut = query_info.query_lut; -- const float* lut_raw = query_lut->data(); -- EXPECT_THAT(std::vector<float>(lut_raw, lut_raw + query_lut->rows()), -- ElementsAre(0, 0, 0, 0, 0, 0)); -- EXPECT_THAT(std::vector<float>(lut_raw + query_lut->rows(), -- lut_raw + query_lut->rows() * 2), -- ElementsAre(-0.3, -0.3, -1.7, 0.6, 0.6, 2.4)); --} -- --TEST(ProcessorTest, AsymmetricHashQuerierSimd) { -- AsymmetricHashingProto proto; -- TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto); -- auto querier = AsymmetricHashQuerier::Create(proto); -- MatrixXf query(5, 6); -- query << 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, -- 1, 0, 1, 0, 0, 1, 1; -- QueryInfo query_info; -- ASSERT_TRUE(querier->Process(query, &query_info)); -- MatrixXf expected_lut(6, 6); -- expected_lut << 0.05, 0.05, 1.45, 0.65, 0.85, 1.45, 0.05, 0.05, 1.45, 0.85, -- 0.65, 1.45, 1.45, 1.45, 0.05, 0.85, 0.65, 0.05, 0.14, 4.34, 0.14, 1.54, -- 2.94, 4.34, 0.14, 4.34, 0.14, 1.54, 2.94, 4.34, 1.94, 9.74, 1.94, 4.54, -- 7.14, 9.74; -- ASSERT_TRUE(query_info.query_lut->isApprox(expected_lut, 1e-5)); -- expected_lut << 0.05, 1.45, 0.14, 0.14, 0.85, 1.45, 0.05, 0.85, 4.34, 1.54, -- 0.65, 1.45, 1.45, 1.45, 0.14, 1.94, 0.65, 0.05, 0.65, 1.45, 1.54, 9.74, -- 2.94, 4.34, 0.05, 0.05, 0.14, 1.94, 2.94, 4.34, 0.05, 0.85, 4.34, 4.54, -- 7.14, 9.74; -- ASSERT_TRUE(query_info.transposed_query_lut->isApprox(expected_lut, 1e-5)); --} -- 
--TEST(ProcessorTest, AsymmetricHashPreprocessingLazyMemoryAllocation) { -- AsymmetricHashingProto proto; -- TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto); -- auto querier = AsymmetricHashQuerier::Create(proto); -- QueryInfo query_info; -- { -- MatrixXf query(5, 2); -- query << 0, 0, 0, 0, 0, 1, 0, 1, 0, 1; -- ASSERT_TRUE(querier->Process(query, &query_info)); -- MatrixXf expected_lut(6, 2); -- expected_lut << 0.05, 0.05, 0.05, 0.05, 1.45, 1.45, 0.14, 4.34, 0.14, 4.34, -- 1.94, 9.74; -- EXPECT_TRUE(query_info.query_lut->isApprox(expected_lut, 1e-5)); -- EXPECT_TRUE(query_info.transposed_query_lut->isApprox(expected_lut, 1e-5)); -- } -- { -- MatrixXf query(5, 6); -- query << 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, -- 0, 1, 0, 1, 0, 0, 1, 1; -- ASSERT_TRUE(querier->Process(query, &query_info)); -- MatrixXf expected_lut(6, 6); -- expected_lut << 0.05, 0.05, 1.45, 0.65, 0.85, 1.45, 0.05, 0.05, 1.45, 0.85, -- 0.65, 1.45, 1.45, 1.45, 0.05, 0.85, 0.65, 0.05, 0.14, 4.34, 0.14, 1.54, -- 2.94, 4.34, 0.14, 4.34, 0.14, 1.54, 2.94, 4.34, 1.94, 9.74, 1.94, 4.54, -- 7.14, 9.74; -- EXPECT_TRUE(query_info.query_lut->isApprox(expected_lut, 1e-5)); -- expected_lut << 0.05, 1.45, 0.14, 0.14, 0.85, 1.45, 0.05, 0.85, 4.34, 1.54, -- 0.65, 1.45, 1.45, 1.45, 0.14, 1.94, 0.65, 0.05, 0.65, 1.45, 1.54, 9.74, -- 2.94, 4.34, 0.05, 0.05, 0.14, 1.94, 2.94, 4.34, 0.05, 0.85, 4.34, 4.54, -- 7.14, 9.74; -- EXPECT_TRUE(query_info.transposed_query_lut->isApprox(expected_lut, 1e-5)); -- } -- { -- MatrixXf query(5, 4); -- query << 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1; -- ASSERT_TRUE(querier->Process(query, &query_info)); -- MatrixXf expected_lut(6, 6); -- expected_lut << 1.45, 0.65, 0.85, 1.45, 0.85, 1.45, 1.45, 0.85, 0.65, 1.45, -- 0.65, 1.45, 0.05, 0.85, 0.65, 0.05, 0.65, 0.05, 0.14, 1.54, 2.94, 4.34, -- 2.94, 4.34, 0.14, 1.54, 2.94, 4.34, 2.94, 4.34, 1.94, 4.54, 7.14, 9.74, -- 7.14, 9.74; -- 
EXPECT_TRUE(query_info.query_lut->isApprox(expected_lut, 1e-5)); -- expected_lut << 1.45, 0.65, 0.14, 2.94, 0.85, 1.45, 0.65, 1.45, 1.54, 4.34, -- 0.65, 1.45, 0.85, 0.05, 2.94, 1.94, 0.65, 0.05, 1.45, 0.85, 4.34, 4.54, -- 2.94, 4.34, 1.45, 0.65, 0.14, 7.14, 2.94, 4.34, 0.85, 0.05, 1.54, 9.74, -- 7.14, 9.74; -- EXPECT_TRUE(query_info.transposed_query_lut->isApprox(expected_lut, 1e-5)); -- } --} -- --TEST(ProcessorTest, AsymmetricHashQuerierUint16) { -- AsymmetricHashingProto proto; -- TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto); -- proto.set_lookup_type(AsymmetricHashingProto::INT16); -- auto querier = AsymmetricHashQuerier::Create(proto); -- CHECK(querier); -- -- MatrixXf query(5, 2); -- query << 0, 1, 0, 1, 0, 1, 0, 1, 0, 1; -- QueryInfo query_info; -- ASSERT_TRUE(querier->Process(query, &query_info)); -- QueryInfo::Matrix<uint16_t> expected_lut(6, 2); -- expected_lut << 0, 295, 0, 295, 295, 0, 19, 906, 19, 906, 399, 2047; -- -- LOG(INFO) << *(query_info.query_lut_uint16); -- -- ASSERT_EQ(*(query_info.query_lut_uint16), expected_lut); -- EXPECT_NEAR(query_info.fixed_point_min, 0.05, 1e-4); -- EXPECT_NEAR(query_info.fixed_point_max, 9.74, 1e-4); --} -- --TEST(ProcessorTest, AsymmetricHashQuerierUint8) { -- AsymmetricHashingProto proto; -- TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto); -- proto.set_lookup_type(AsymmetricHashingProto::INT8); -- auto querier = AsymmetricHashQuerier::Create(proto); -- CHECK(querier); -- -- MatrixXf query(5, 2); -- query << 0, 1, 0, 1, 0, 1, 0, 1, 0, 1; -- QueryInfo query_info; -- ASSERT_TRUE(querier->Process(query, &query_info)); -- QueryInfo::Matrix<uint8_t> expected_lut(6, 2); -- expected_lut << 0, 36, 0, 36, 36, 0, 2, 112, 2, 112, 49, 255; -- ASSERT_EQ(*(query_info.query_lut_uint8), expected_lut); -- EXPECT_NEAR(query_info.fixed_point_min, 0.05, 1e-4); -- EXPECT_NEAR(query_info.fixed_point_max, 9.74, 1e-4); --} -- --class SearcherTest : public TestWithParam<size_t> {}; -- 
--TEST_P(SearcherTest, LinearLeafSearcherNonSimd) { -- MatrixXf query(3, 2); -- query << 0, 1, 2, 3, 3, 1; -- std::shared_ptr<MatrixXf> database(new MatrixXf(3, 5)); -- *database << 0, 1, 2, 2, 1, 1, 0, 1, 2, 2, 2, 2, 5, 6, 1; -- std::vector<TopN> top_n; -- for (int i = 0; i < 2; ++i) { -- top_n.emplace_back( -- TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1))); -- } -- auto leaf_searcher = LinearLeafSearcher::Create(database); -- ASSERT_TRUE(leaf_searcher->FindNeighbors(query, &top_n)); -- -- constexpr float kEps = 1e-5; -- auto extracted = top_n[0].Take(); -- EXPECT_NEAR(2.0, extracted[0].first, kEps); -- EXPECT_NEAR(5.0, extracted[1].first, kEps); -- EXPECT_NEAR(6.0, extracted[2].first, kEps); -- EXPECT_EQ(0, extracted[0].second); -- EXPECT_EQ(4, extracted[1].second); -- EXPECT_EQ(1, extracted[2].second); -- -- extracted = top_n[1].Take(); -- EXPECT_NEAR(1.0, extracted[0].first, kEps); -- EXPECT_NEAR(6.0, extracted[1].first, kEps); -- EXPECT_NEAR(10.0, extracted[2].first, kEps); -- EXPECT_EQ(4, extracted[0].second); -- EXPECT_EQ(0, extracted[1].second); -- EXPECT_EQ(1, extracted[2].second); --} -- --TEST_P(SearcherTest, LinearLeafSearcherNonSimdDotProduct) { -- MatrixXf query(3, 2); -- query << 0, 1, 2, 3, 3, 1; -- auto database = std::make_shared<MatrixXf>(3, 5); -- *database << 0, 1, 2, 2, 1, 1, 0, 1, 2, 2, 2, 2, 5, 6, 1; -- -- std::vector<TopN> top_n( -- 2, TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1))); -- -- auto leaf_searcher = LinearLeafSearcher::Create(database, DOT_PRODUCT); -- ASSERT_TRUE(leaf_searcher->FindNeighbors(query, &top_n)); -- -- auto extracted = top_n[0].Take(); -- EXPECT_THAT(extracted, ElementsAre(Pair(-22, 3), Pair(-17, 2), Pair(-8, 0))); -- -- extracted = top_n[1].Take(); -- EXPECT_THAT(extracted, ElementsAre(Pair(-14, 3), Pair(-10, 2), Pair(-8, 4))); --} -- --TEST_P(SearcherTest, AsymmetricHashNonSimd) { -- MatrixXf query(5, 2); -- query << 0, 1, 0, 1, 0, 1, 0, 1, 0, 1; -- std::shared_ptr<Matrix8u> 
database(new Matrix8u(2, 6)); -- *database << 0, 1, 2, 2, 1, 0, 1, 0, 1, 2, 2, 0; -- AsymmetricHashingProto proto; -- TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto); -- auto querier = AsymmetricHashQuerier::Create(proto); -- auto leaf_searcher = -- AsymmetricHashLeafSearcher::Create(database, 0, std::move(querier)); -- std::vector<TopN> top_n; -- for (int i = 0; i < 2; ++i) { -- top_n.emplace_back( -- TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1))); -- } -- ASSERT_TRUE(leaf_searcher->FindNeighbors(query, &top_n)); -- -- constexpr float kEps = 1e-5; -- auto extracted = top_n[0].Take(); -- EXPECT_NEAR(0.19, extracted[0].first, kEps); -- EXPECT_NEAR(0.19, extracted[1].first, kEps); -- EXPECT_NEAR(0.19, extracted[2].first, kEps); -- -- extracted = top_n[1].Take(); -- EXPECT_NEAR(4.39, extracted[0].first, kEps); -- EXPECT_NEAR(5.79, extracted[1].first, kEps); -- EXPECT_NEAR(5.79, extracted[2].first, kEps); --} -- --#if defined(__ARM_NEON__) || defined(__SSE__) --TEST_P(SearcherTest, AsymmetricHashSimdFloat32x4) { -- MatrixXf query(5, 6); -- query << 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, -- 1, 0, 1, 0, 0, 1, 1; -- std::shared_ptr<Matrix8u> database(new Matrix8u(2, 9)); -- *database << 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2; -- AsymmetricHashingProto proto; -- TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto); -- auto querier = AsymmetricHashQuerier::Create(proto); -- auto leaf_searcher = AsymmetricHashLeafSearcher::Create( -- database, 0, std::move(querier), GetParam()); -- std::vector<TopN> top_n; -- for (int i = 0; i < 6; ++i) { -- top_n.emplace_back( -- TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1))); -- } -- ASSERT_TRUE(leaf_searcher->FindNeighbors(query, &top_n)); -- -- constexpr float kEps = 1e-5; -- auto extracted = top_n[0].Take(); -- EXPECT_NEAR(0.19, extracted[0].first, kEps); -- EXPECT_NEAR(0.19, extracted[1].first, kEps); -- 
EXPECT_NEAR(0.19, extracted[2].first, kEps); -- -- extracted = top_n[1].Take(); -- EXPECT_NEAR(4.39, extracted[0].first, kEps); -- EXPECT_NEAR(4.39, extracted[1].first, kEps); -- EXPECT_NEAR(4.39, extracted[2].first, kEps); -- -- extracted = top_n[2].Take(); -- EXPECT_NEAR(0.19, extracted[0].first, kEps); -- EXPECT_NEAR(0.19, extracted[1].first, kEps); -- EXPECT_NEAR(1.59, extracted[2].first, kEps); -- -- extracted = top_n[3].Take(); -- EXPECT_NEAR(2.19, extracted[0].first, kEps); -- EXPECT_NEAR(2.19, extracted[1].first, kEps); -- EXPECT_NEAR(2.39, extracted[2].first, kEps); -- -- extracted = top_n[4].Take(); -- EXPECT_NEAR(3.59, extracted[0].first, kEps); -- EXPECT_NEAR(3.59, extracted[1].first, kEps); -- EXPECT_NEAR(3.59, extracted[2].first, kEps); -- -- extracted = top_n[5].Take(); -- EXPECT_NEAR(4.39, extracted[0].first, kEps); -- EXPECT_NEAR(4.39, extracted[1].first, kEps); -- EXPECT_NEAR(5.79, extracted[2].first, kEps); --} --#endif -- --#if defined(__ARM_NEON__) || defined(__SSE__) --TEST_P(SearcherTest, AsymmetricHashSimdInt16x8) { -- MatrixXf query(5, 11); -- query << 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, -- 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, -- 1, 1, 1, 0, 0, 0, 0; -- std::shared_ptr<Matrix8u> database(new Matrix8u(2, 9)); -- *database << 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2; -- AsymmetricHashingProto proto; -- TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto); -- proto.set_lookup_type(AsymmetricHashingProto::INT16); -- auto querier = AsymmetricHashQuerier::Create(proto); -- QueryInfo query_info; -- ASSERT_TRUE(querier->Process(query, &query_info)); -- -- auto leaf_searcher = AsymmetricHashLeafSearcher::Create( -- database, 0, std::move(querier), GetParam()); -- std::vector<TopN> top_n; -- for (int i = 0; i < 11; ++i) { -- top_n.emplace_back( -- TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1))); -- } -- 
ASSERT_TRUE(leaf_searcher->FindNeighbors(query, &top_n)); -- -- auto extracted = top_n[0].Take(); -- constexpr float kEps = 5e-2; -- EXPECT_NEAR(0.19, extracted[0].first, kEps); -- EXPECT_NEAR(0.19, extracted[1].first, kEps); -- EXPECT_NEAR(0.19, extracted[2].first, kEps); -- -- extracted = top_n[1].Take(); -- EXPECT_NEAR(4.39, extracted[0].first, kEps); -- EXPECT_NEAR(4.39, extracted[1].first, kEps); -- EXPECT_NEAR(4.39, extracted[2].first, kEps); -- -- extracted = top_n[2].Take(); -- EXPECT_NEAR(0.19, extracted[0].first, kEps); -- EXPECT_NEAR(0.19, extracted[1].first, kEps); -- EXPECT_NEAR(1.59, extracted[2].first, kEps); -- -- extracted = top_n[3].Take(); -- EXPECT_NEAR(2.19, extracted[0].first, kEps); -- EXPECT_NEAR(2.19, extracted[1].first, kEps); -- EXPECT_NEAR(2.39, extracted[2].first, kEps); -- -- extracted = top_n[4].Take(); -- EXPECT_NEAR(3.59, extracted[0].first, kEps); -- EXPECT_NEAR(3.59, extracted[1].first, kEps); -- EXPECT_NEAR(3.59, extracted[2].first, kEps); -- -- extracted = top_n[5].Take(); -- EXPECT_NEAR(4.39, extracted[0].first, kEps); -- EXPECT_NEAR(4.39, extracted[1].first, kEps); -- EXPECT_NEAR(5.79, extracted[2].first, kEps); -- -- extracted = top_n[6].Take(); -- EXPECT_NEAR(1.39, extracted[0].first, kEps); -- EXPECT_NEAR(1.39, extracted[1].first, kEps); -- EXPECT_NEAR(1.79, extracted[2].first, kEps); -- -- extracted = top_n[7].Take(); -- EXPECT_NEAR(1.59, extracted[0].first, kEps); -- EXPECT_NEAR(1.59, extracted[1].first, kEps); -- EXPECT_NEAR(1.59, extracted[2].first, kEps); -- -- extracted = top_n[8].Take(); -- EXPECT_NEAR(1.39, extracted[0].first, kEps); -- EXPECT_NEAR(1.39, extracted[1].first, kEps); -- EXPECT_NEAR(1.79, extracted[2].first, kEps); -- -- extracted = top_n[9].Take(); -- EXPECT_NEAR(0.79, extracted[0].first, kEps); -- EXPECT_NEAR(0.79, extracted[1].first, kEps); -- EXPECT_NEAR(0.99, extracted[2].first, kEps); -- -- extracted = top_n[10].Take(); -- EXPECT_NEAR(0.79, extracted[0].first, kEps); -- EXPECT_NEAR(0.79, 
extracted[1].first, kEps); -- EXPECT_NEAR(0.79, extracted[2].first, kEps); --} --#endif -- --#if defined(__ARM_NEON__) || defined(__SSE__) --TEST_P(SearcherTest, AsymmetricHashMiniBatchedSimdFail) { -- std::shared_ptr<Matrix8u> database(new Matrix8u(2, 9)); -- *database << 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2; -- AsymmetricHashingProto proto; -- TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto); -- proto.set_lookup_type(AsymmetricHashingProto::FLOAT); -- proto.set_query_distance(DistanceMeasure::UNSPECIFIED); -- auto querier = AsymmetricHashQuerier::Create(proto); -- auto leaf_searcher = AsymmetricHashLeafSearcher::Create( -- database, 0, std::move(querier), GetParam()); -- -- MatrixXf queries(6, 6); -- queries << 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, -- 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0; -- std::vector<TopN> top_n; -- for (int i = 0; i < queries.cols(); ++i) { -- top_n.emplace_back( -- TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1))); -- } -- EXPECT_FALSE(leaf_searcher->FindNeighbors(queries, &top_n)); --} --#endif -- --INSTANTIATE_TEST_SUITE_P( -- SearcherTest, -- SearcherTest, -- Values(std::numeric_limits<size_t>::max(), 1, 2, 3, 7, 23)); -- --} // namespace -- --} // namespace core --} // namespace scann_ondevice --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/simd_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/simd_utils.h -deleted file mode 100644 -index f239ec482382e..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/simd_utils.h -+++ /dev/null -@@ -1,303 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. 
--You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_SIMD_UTILS_H_ --#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_SIMD_UTILS_H_ -- --#include <cstdint> --#ifdef __SSE__ --#include <x86intrin.h> --#endif --#ifdef __ARM_NEON__ --#include <arm_neon.h> --#endif -- --#include <cmath> --#include <memory> -- --#include "tensorflow_lite_support/cc/port/integral_types.h" -- --namespace tflite { --namespace scann_ondevice { --namespace core { --class SimdFloat32x1 { -- float value_; -- -- public: -- static constexpr size_t size() { return 1; } -- -- void setzero() { value_ = 0.0f; } -- -- void load(const float* mem) { value_ = *mem; } -- void dequantize_accum_storeu(float* mem, float, float) const { -- *mem += value_; -- } -- SimdFloat32x1& operator+=(const SimdFloat32x1& rhs) { -- value_ += rhs.value_; -- return *this; -- } --}; --#ifdef __SSE__ --class SimdFloat32x4 { -- __m128 value_; -- -- public: -- static constexpr size_t size() { return 4; } -- -- void setzero() { value_ = _mm_setzero_ps(); } -- void load(const float* mem) { value_ = _mm_load_ps(mem); } -- void loadu(const float* mem) { value_ = _mm_loadu_ps(mem); } -- void storeu(float* mem) const { _mm_storeu_ps(mem, value_); } -- -- void dequantize_accum_storeu(float* mem, float, float) const { -- SimdFloat32x4 simd; -- simd.loadu(mem); -- simd += *this; -- simd.storeu(mem); -- } -- -- SimdFloat32x4& operator+=(const SimdFloat32x4& rhs) { -- value_ = _mm_add_ps(rhs.value_, value_); -- return *this; -- } --}; 
--#endif --#ifdef __AVX__ --class SimdFloat32x8 { -- __m256 value_; -- -- public: -- static constexpr size_t size() { return 8; } -- -- void setzero() { value_ = _mm256_setzero_ps(); } -- void load(const float* mem) { value_ = _mm256_load_ps(mem); } -- -- void loadu(const float* mem) { value_ = _mm256_loadu_ps(mem); } -- -- void storeu(float* mem) { _mm256_storeu_ps(mem, value_); } -- -- void dequantize_accum_storeu(float* mem, float, float) const { -- SimdFloat32x8 simd; -- simd.loadu(mem); -- simd += *this; -- simd.storeu(mem); -- } -- -- SimdFloat32x8& operator+=(const SimdFloat32x8& rhs) { -- value_ = _mm256_add_ps(rhs.value_, value_); -- return *this; -- } --}; --#endif --#ifdef __ARM_NEON__ --class SimdFloat32x4 { -- float32x4_t value_; -- -- public: -- static constexpr size_t size() { return 4; } -- -- void setzero() { value_ = vmovq_n_f32(0); } -- void load(const float* mem) { value_ = vld1q_f32(mem); } -- void loadu(const float* mem) { value_ = vld1q_f32(mem); } -- void storeu(float* mem) const { vst1q_f32(mem, value_); } -- -- void dequantize_accum_storeu(float* mem, float, float) const { -- SimdFloat32x4 simd; -- simd.loadu(mem); -- simd += *this; -- simd.storeu(mem); -- } -- -- SimdFloat32x4& operator+=(const SimdFloat32x4& rhs) { -- value_ = vaddq_f32(rhs.value_, value_); -- return *this; -- } --}; --#endif -- --class SimdInt16x1 { -- uint16_t value_; -- -- public: -- static constexpr size_t size() { return 1; } -- -- void setzero() { value_ = 0; } -- -- void load(const uint16_t* mem) { value_ = *mem; } -- -- void load(const uint8_t* mem) { value_ = *mem; } -- void dequantize_accum_storeu(float* mem, float scale, float offset) const { -- *mem += scale * value_ + offset; -- } -- -- SimdInt16x1& operator+=(const SimdInt16x1& rhs) { -- value_ += rhs.value_; -- return *this; -- } --}; --#ifdef __SSE4_1__ --class SimdInt16x8 { -- __m128i value_; -- -- public: -- static constexpr size_t size() { return 8; } -- -- void setzero() { value_ = 
_mm_setzero_si128(); } -- void load(const uint16_t* mem) { -- value_ = _mm_load_si128(reinterpret_cast<const __m128i*>(mem)); -- } -- -- void loadu(const uint16_t* mem) { -- value_ = _mm_loadu_si128(reinterpret_cast<const __m128i*>(mem)); -- } -- void load(const uint8_t* mem) { -- __m128i tmp = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(mem)); -- value_ = _mm_cvtepu8_epi16(tmp); -- } -- -- void loadu(const uint8_t* mem) { -- __m128i tmp = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(mem)); -- value_ = _mm_cvtepu8_epi16(tmp); -- } -- void dequantize_accum_storeu(float* mem, float scale, float offset) const { -- __m128 dst0 = _mm_loadu_ps(mem); -- __m128 dst1 = _mm_loadu_ps(mem + 4); -- __m128i lo_i16 = value_; -- __m128i hi_i16 = _mm_unpackhi_epi64(value_, value_); -- __m128i lo_i32 = _mm_cvtepu16_epi32(lo_i16); -- __m128i hi_i32 = _mm_cvtepu16_epi32(hi_i16); -- __m128 lo_f32 = _mm_cvtepi32_ps(lo_i32); -- __m128 hi_f32 = _mm_cvtepi32_ps(hi_i32); -- __m128 offset_simd = _mm_set1_ps(offset); -- __m128 scale_simd = _mm_set1_ps(scale); -- lo_f32 = _mm_mul_ps(scale_simd, lo_f32); -- hi_f32 = _mm_mul_ps(scale_simd, hi_f32); -- lo_f32 = _mm_add_ps(lo_f32, offset_simd); -- hi_f32 = _mm_add_ps(hi_f32, offset_simd); -- dst0 = _mm_add_ps(dst0, lo_f32); -- dst1 = _mm_add_ps(dst1, hi_f32); -- _mm_storeu_ps(mem, dst0); -- _mm_storeu_ps(mem + 4, dst1); -- } -- -- SimdInt16x8& operator+=(const SimdInt16x8& rhs) { -- value_ = _mm_add_epi16(rhs.value_, value_); -- return *this; -- } --}; --#endif --#ifdef __ARM_NEON__ --class SimdInt16x8 { -- uint16x8_t value_; -- -- public: -- static constexpr size_t size() { return 8; } -- -- void setzero() { value_ = vmovq_n_u16(0); } -- void load(const uint16* mem) { value_ = vld1q_u16(mem); } -- -- void loadu(const uint16* mem) { value_ = vld1q_u16(mem); } -- void load(const uint8* mem) { -- uint8x8_t tmp = vld1_u8(mem); -- value_ = vmovl_u8(tmp); -- } -- -- void loadu(const uint8* mem) { -- uint8x8_t tmp = vld1_u8(mem); -- value_ = 
vmovl_u8(tmp); -- } -- void dequantize_accum_storeu(float* mem, float scale, float offset) const { -- float32x4_t dst0 = vld1q_f32(mem); -- float32x4_t dst1 = vld1q_f32(mem + 4); -- uint16x4_t lo_i16 = vget_low_u16(value_); -- uint16x4_t hi_i16 = vget_high_u16(value_); -- uint32x4_t lo_i32 = vmovl_u16(lo_i16); -- uint32x4_t hi_i32 = vmovl_u16(hi_i16); -- float32x4_t lo_f32 = vcvtq_f32_u32(lo_i32); -- float32x4_t hi_f32 = vcvtq_f32_u32(hi_i32); -- float32x4_t offset_simd = vdupq_n_f32(offset); -- float32x4_t scale_simd = vdupq_n_f32(scale); -- lo_f32 = vmlaq_f32(offset_simd, scale_simd, lo_f32); -- hi_f32 = vmlaq_f32(offset_simd, scale_simd, hi_f32); -- dst0 = vaddq_f32(dst0, lo_f32); -- dst1 = vaddq_f32(dst1, hi_f32); -- vst1q_f32(mem, dst0); -- vst1q_f32(mem + 4, dst1); -- } -- -- SimdInt16x8& operator+=(const SimdInt16x8& rhs) { -- value_ = vaddq_u16(rhs.value_, value_); -- return *this; -- } --}; --#endif --#ifdef __AVX2__ --class SimdInt16x16 { -- __m256i value_; -- -- public: -- static constexpr size_t size() { return 16; } -- -- void setzero() { value_ = _mm256_setzero_si256(); } -- -- void load(const uint16* mem) { -- value_ = _mm256_load_si256(reinterpret_cast<const __m256i*>(mem)); -- } -- -- void loadu(const uint16* mem) { -- value_ = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(mem)); -- } -- -- void load(const uint8* mem) { -- __m128i tmp = _mm_load_si128(reinterpret_cast<const __m128i*>(mem)); -- value_ = _mm256_cvtepu8_epi16(tmp); -- } -- -- void loadu(const uint8* mem) { -- __m128i tmp = _mm_loadu_si128(reinterpret_cast<const __m128i*>(mem)); -- value_ = _mm256_cvtepu8_epi16(tmp); -- } -- -- void dequantize_accum_storeu(float* mem, float scale, float offset) const { -- __m256 dst0 = _mm256_loadu_ps(mem); -- __m256 dst1 = _mm256_loadu_ps(mem + 8); -- __m128i lo_i16 = _mm256_castsi256_si128(value_); -- __m128i hi_i16 = _mm256_extractf128_si256(value_, 1); -- __m256i lo_i32 = _mm256_cvtepu16_epi32(lo_i16); -- __m256i hi_i32 = 
_mm256_cvtepu16_epi32(hi_i16); -- __m256 lo_f32 = _mm256_cvtepi32_ps(lo_i32); -- __m256 hi_f32 = _mm256_cvtepi32_ps(hi_i32); -- __m256 offset_simd = _mm256_set1_ps(offset); -- __m256 scale_simd = _mm256_set1_ps(scale); -- lo_f32 = _mm256_fmadd_ps(scale_simd, lo_f32, offset_simd); -- hi_f32 = _mm256_fmadd_ps(scale_simd, hi_f32, offset_simd); -- dst0 = _mm256_add_ps(dst0, lo_f32); -- dst1 = _mm256_add_ps(dst1, hi_f32); -- _mm256_storeu_ps(mem, dst0); -- _mm256_storeu_ps(mem + 8, dst1); -- } -- -- SimdInt16x16& operator+=(const SimdInt16x16& rhs) { -- value_ = _mm256_add_epi16(rhs.value_, value_); -- return *this; -- } --}; --#endif -- --} // namespace core --} // namespace scann_ondevice --} // namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_SIMD_UTILS_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc -deleted file mode 100644 -index e8be5f6572f17..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc -+++ /dev/null -@@ -1,138 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#include "tensorflow_lite_support/scann_ondevice/cc/index.h" -- --#include <cstddef> --#include <memory> -- --#include "absl/memory/memory.h" // from @com_google_absl --#include "absl/status/status.h" // from @com_google_absl --#include "absl/status/statusor.h" // from @com_google_absl --#include "absl/strings/str_format.h" // from @com_google_absl --#include "absl/strings/string_view.h" // from @com_google_absl --#include "leveldb/cache.h" // from @com_google_leveldb --#include "leveldb/iterator.h" // from @com_google_leveldb --#include "leveldb/options.h" // from @com_google_leveldb --#include "leveldb/slice.h" // from @com_google_leveldb --#include "leveldb/status.h" // from @com_google_leveldb --#include "leveldb/table.h" // from @com_google_leveldb --#include "tensorflow_lite_support/cc/port/status_macros.h" --#include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h" --#include "tensorflow_lite_support/scann_ondevice/cc/utils.h" --#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h" -- --namespace tflite { --namespace scann_ondevice { -- --namespace { -- --// Helper function to get the iterator value associated to the provided key. --// --// Important: the underlying storage for the returned string view is owned by --// the provided iterator, and only valid until this iterator is used again with --// a different key. 
See: --// https://github.com/google/leveldb/blob/main/include/leveldb/iterator.h --absl::StatusOr<absl::string_view> GetValueForKey(leveldb::Iterator* iterator, -- std::string& key) { -- iterator->Seek(key); -- if (!iterator->Valid() || iterator->key() != key || -- !iterator->status().ok()) { -- return absl::NotFoundError( -- absl::StrFormat("Unable to find key in the index: %s", key)); -- } -- leveldb::Slice value = iterator->value(); -- return absl::string_view(value.data(), value.size()); --} --} // namespace -- --/* static */ --absl::StatusOr<std::unique_ptr<Index>> Index::CreateFromIndexBuffer( -- const char* buffer_data, -- size_t buffer_size) { -- // Use absl::WrapUnique() to call private constructor: -- // https://abseil.io/tips/126. -- std::unique_ptr<Index> index = absl::WrapUnique(new Index()); -- RETURN_IF_ERROR(index->InitFromBuffer(buffer_data, buffer_size)); -- return index; --} -- --absl::StatusOr<IndexConfig> Index::GetIndexConfig() const { -- std::string key(kIndexConfigKey); -- ASSIGN_OR_RETURN(absl::string_view value, -- GetValueForKey(config_iterator_.get(), key)); -- IndexConfig config; -- if (!config.ParseFromString(std::string(value))) { -- return absl::InternalError("Unable to parse IndexConfig proto"); -- } -- return config; --} -- --absl::StatusOr<absl::string_view> Index::GetUserInfo() const { -- std::string key(kUserInfoKey); -- // Intercept NotFound errors and return empty string instead. 
-- auto user_info_or = GetValueForKey(info_iterator_.get(), key); -- if (user_info_or.status().code() == absl::StatusCode::kNotFound) { -- return ""; -- } -- return user_info_or; --} -- --absl::StatusOr<absl::string_view> Index::GetPartitionAtIndex(uint32_t i) const { -- std::string key(GetPartitionKey(i)); -- return GetValueForKey(embedding_iterator_.get(), key); --} -- --absl::StatusOr<absl::string_view> Index::GetMetadataAtIndex(uint32_t i) const { -- std::string key(GetMetadataKey(i)); -- return GetValueForKey(metadata_iterator_.get(), key); --} -- --absl::Status Index::InitFromBuffer(const char* buffer_data, -- size_t buffer_size) { -- // Sanity check. -- if (buffer_data == nullptr) { -- return absl::InvalidArgumentError("Buffer cannot be null"); -- } -- // Create file from buffer. -- file_ = absl::make_unique<MemRandomAccessFile>(buffer_data, buffer_size); -- // Create options with cache disabled, as this saves memory and has negligible -- // impact on performance in this setup as any key can be accessed anytime. -- leveldb::Options options; -- cache_ = absl::WrapUnique(leveldb::NewLRUCache(0)); -- options.block_cache = cache_.get(); -- // Build Table from file and options. -- leveldb::Table* table; -- leveldb::Status status = -- leveldb::Table::Open(options, file_.get(), buffer_size, &table); -- if (!status.ok()) { -- return absl::InternalError( -- absl::StrFormat("Unable to open levelDB table: %s", status.ToString())); -- } -- table_ = absl::WrapUnique(table); -- // Create iterators. 
-- config_iterator_ = -- absl::WrapUnique(table_->NewIterator(leveldb::ReadOptions())); -- info_iterator_ = -- absl::WrapUnique(table_->NewIterator(leveldb::ReadOptions())); -- embedding_iterator_ = -- absl::WrapUnique(table_->NewIterator(leveldb::ReadOptions())); -- metadata_iterator_ = -- absl::WrapUnique(table_->NewIterator(leveldb::ReadOptions())); -- return absl::OkStatus(); --} -- --} // namespace scann_ondevice --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h -deleted file mode 100644 -index 15e709183a606..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h -+++ /dev/null -@@ -1,91 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_H_ --#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_H_ -- --#include <memory> -- --#include "absl/status/status.h" // from @com_google_absl --#include "absl/status/statusor.h" // from @com_google_absl --#include "absl/strings/string_view.h" // from @com_google_absl --#include "leveldb/cache.h" // from @com_google_leveldb --#include "leveldb/iterator.h" // from @com_google_leveldb --#include "leveldb/table.h" // from @com_google_leveldb --#include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h" --#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h" -- --namespace tflite { --namespace scann_ondevice { -- --// Helper class for getting access to the data contained in the LevelDB index --// file. --// --// This class is NOT thread-safe. --class Index { -- public: -- // Creates an Index from the provided buffer. Ownership is transferred to the -- // caller. Returns an error if the creation failed, which may happen e.g. if -- // the provided buffer is not a valid LevelDB index file. -- // -- // Warning: Does not take ownership of the provided buffer, which must outlive -- // this object. -- static absl::StatusOr<std::unique_ptr<Index>> CreateFromIndexBuffer( -- const char* buffer_data, -- size_t buffer_size); -- -- // Parses and returns the `IndexConfig` stored in the index file. -- absl::StatusOr<IndexConfig> GetIndexConfig() const; -- -- // Provides access to the opaque user info stored in the index file (if any), -- // in raw binary form. Returns an empty string if the index doesn't contain -- // user info. -- absl::StatusOr<absl::string_view> GetUserInfo() const; -- -- // Provides access to the partition data corresponding to the i-th leaf in the -- // order specified in the `IndexConfig`, in raw binary form. 
-- // -- // Warning: In order to avoid unnecessary copies, the underlying pointer for -- // the returned string view is only valid until next call to this method. -- absl::StatusOr<absl::string_view> GetPartitionAtIndex(uint32_t i) const; -- -- // Provides access to the metadata associated with the i-th embedding in the -- // index, in raw binary form. -- // -- // Warning: In order to avoid unnecessary copies, the underlying pointer for -- // the returned string view is only valid until next call to this method. -- absl::StatusOr<absl::string_view> GetMetadataAtIndex(uint32_t i) const; -- -- private: -- // Private default constructor, called from CreateFromBuffer(). -- Index() = default; -- // Initializes the Index from the provided buffer. -- absl::Status InitFromBuffer(const char* buffer_data, size_t buffer_size); -- -- std::unique_ptr<leveldb::Table> table_; -- std::unique_ptr<MemRandomAccessFile> file_; -- std::unique_ptr<leveldb::Cache> cache_; -- // One iterator per getter, so that calls from one getter don't invalidate -- // results from another one. -- std::unique_ptr<leveldb::Iterator> config_iterator_; -- std::unique_ptr<leveldb::Iterator> info_iterator_; -- std::unique_ptr<leveldb::Iterator> embedding_iterator_; -- std::unique_ptr<leveldb::Iterator> metadata_iterator_; --}; -- --} // namespace scann_ondevice --} // namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc -deleted file mode 100644 -index c77f7299e64a6..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc -+++ /dev/null -@@ -1,177 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
-- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#include "tensorflow_lite_support/scann_ondevice/cc/index_builder.h" -- --#include <string> --#include <tuple> --#include <vector> -- --#include "absl/container/btree_map.h" // from @com_google_absl --#include "absl/status/status.h" // from @com_google_absl --#include "absl/strings/str_format.h" // from @com_google_absl --#include "leveldb/options.h" // from @com_google_leveldb --#include "leveldb/slice.h" // from @com_google_leveldb --#include "leveldb/status.h" // from @com_google_leveldb --#include "leveldb/table_builder.h" // from @com_google_leveldb --#include "leveldb/write_batch.h" // from @com_google_leveldb --#include "tensorflow_lite_support/cc/port/status_macros.h" --#include "tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h" --#include "tensorflow_lite_support/scann_ondevice/cc/utils.h" --#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h" -- --namespace tflite { --namespace scann_ondevice { -- --namespace { -- --absl::Status LevelDBStatusToAbsl(leveldb::Status leveldb_status) { -- if (leveldb_status.ok()) { -- return absl::OkStatus(); -- } else if (leveldb_status.IsInvalidArgument()) { -- return absl::InvalidArgumentError(leveldb_status.ToString()); -- } else if (leveldb_status.IsNotFound()) { -- return absl::NotFoundError(leveldb_status.ToString()); -- } else if 
(leveldb_status.IsNotSupportedError()) { -- return absl::UnimplementedError(leveldb_status.ToString()); -- } else { -- return absl::InternalError(leveldb_status.ToString()); -- } --} -- --template <typename T> --absl::StatusOr<std::string> CreateIndexBufferImpl( -- absl::Span<const T> database, -- absl::Span<const uint32_t> partition_assignment, -- absl::Span<const std::string> metadata, -- const std::string& userinfo, -- IndexConfig index_config, -- bool compression) { -- if (partition_assignment.size() != metadata.size()) { -- return absl::InvalidArgumentError( -- "Size of partition assignment and metadata mismatch"); -- } -- -- if (database.size() / index_config.embedding_dim() != metadata.size()) { -- return absl::InvalidArgumentError( -- "Number of embeddings differs from number of metadata"); -- } -- -- const size_t num_partitions = -- index_config.scann_config().partitioner().leaf_size(); -- -- std::vector<std::vector<char>> partition_bytes(num_partitions); -- std::vector<std::vector<std::string>> partition_metadata(num_partitions); -- -- const size_t per_embedding_bytes = sizeof(T) * index_config.embedding_dim(); -- const char* database_bytes = reinterpret_cast<const char*>(database.data()); -- for (size_t i = 0; i < partition_assignment.size(); ++i) { -- const size_t partition_idx = partition_assignment[i]; -- if (partition_idx >= num_partitions) { -- return absl::InvalidArgumentError(absl::StrFormat( -- "Partition index %d is larger than number of partitions: %d", -- partition_idx, num_partitions)); -- } -- partition_bytes[partition_idx].insert( -- partition_bytes[partition_idx].end(), -- database_bytes + i * per_embedding_bytes, -- database_bytes + (i + 1) * per_embedding_bytes); -- partition_metadata[partition_idx].push_back(metadata[i]); -- } -- -- std::vector<std::string> flatten_metadata; -- flatten_metadata.reserve(metadata.size()); -- for (auto partition : partition_metadata) { -- const size_t offset = flatten_metadata.size(); -- 
index_config.mutable_global_partition_offsets()->Add(offset); -- flatten_metadata.insert(flatten_metadata.end(), partition.begin(), -- partition.end()); -- partition.clear(); -- partition.shrink_to_fit(); -- } -- -- std::string buffer; -- ASSIGN_OR_RETURN(auto mem_writable_file, MemWritableFile::Create(&buffer)); -- -- leveldb::Options options; -- options.compression = -- compression ? leveldb::kSnappyCompression : leveldb::kNoCompression; -- leveldb::TableBuilder table_builder(options, mem_writable_file.get()); -- -- // Keys must be added in ascending *lexical* order, e.g: -- // E_0, E_1, E_10, E_11, [...], E_18, E_19, E_2, E_20, E_21, [...] -- // We're using btree_map to reorder partition and metadata keys. -- absl::btree_map<std::string, size_t> ordered_partition_key_to_index; -- for (size_t i = 0; i < partition_bytes.size(); ++i) { -- ordered_partition_key_to_index[GetPartitionKey(i)] = i; -- } -- for (auto [key, index] : ordered_partition_key_to_index) { -- table_builder.Add(leveldb::Slice(key), -- leveldb::Slice(partition_bytes[index].data(), -- partition_bytes[index].size())); -- } -- table_builder.Add(leveldb::Slice(kIndexConfigKey), -- leveldb::Slice(index_config.SerializeAsString())); -- absl::btree_map<std::string, size_t> ordered_metadata_key_to_index; -- for (size_t i = 0; i < flatten_metadata.size(); ++i) { -- ordered_metadata_key_to_index[GetMetadataKey(i)] = i; -- } -- for (auto [key, index] : ordered_metadata_key_to_index) { -- table_builder.Add(leveldb::Slice(key), -- leveldb::Slice(flatten_metadata[index])); -- } -- table_builder.Add(leveldb::Slice(kUserInfoKey), leveldb::Slice(userinfo)); -- -- const auto status = table_builder.Finish(); -- if (!status.ok()) { -- return LevelDBStatusToAbsl(status); -- } -- -- return buffer; --} -- --} // namespace -- --absl::StatusOr<std::string> CreateIndexBuffer(const IndexedArtifacts& artifacts, -- bool compression) { -- if (artifacts.hashed_database.has_value() && -- artifacts.float_database.has_value()) { 
-- return absl::InvalidArgumentError( -- "Can not have both float database and hashed database"); -- } -- -- IndexConfig index_config; -- *index_config.mutable_scann_config() = artifacts.config; -- index_config.set_embedding_dim(artifacts.embedding_dim); -- if (artifacts.hashed_database.has_value()) { -- index_config.set_embedding_type(index_config.UINT8); -- return CreateIndexBufferImpl(artifacts.hashed_database.value(), -- artifacts.partition_assignment, -- artifacts.metadata, artifacts.userinfo, -- std::move(index_config), compression); -- } else if (artifacts.float_database.has_value()) { -- index_config.set_embedding_type(index_config.FLOAT); -- return CreateIndexBufferImpl(artifacts.float_database.value(), -- artifacts.partition_assignment, -- artifacts.metadata, artifacts.userinfo, -- std::move(index_config), compression); -- } else { -- return absl::InvalidArgumentError( -- "Need either hashed_database or float_database"); -- } --} -- --} // namespace scann_ondevice --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h -deleted file mode 100644 -index 5701796943e28..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h -+++ /dev/null -@@ -1,68 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
--See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_FILE_MUTATOR_H_ --#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_FILE_MUTATOR_H_ -- --#include "absl/status/statusor.h" // from @com_google_absl --#include "absl/strings/string_view.h" // from @com_google_absl --#include "absl/types/optional.h" // from @com_google_absl --#include "absl/types/span.h" // from @com_google_absl --#include "leveldb/db.h" // from @com_google_leveldb --#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h" --#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h" -- --namespace tflite { --namespace scann_ondevice { -- --struct IndexedArtifacts { -- // Config for on-device scam. Contains pretrained parts such as partition -- // centroids, compression codebook. -- tflite::scann_ondevice::core::ScannOnDeviceConfig config; -- -- // The dimension of each processed embedding in either hashed_database or -- // float_database. Note that if hashing is enabled, it can be different from -- // the original embedding dimension depending on the config. -- uint32_t embedding_dim; -- -- // Flattened database embeddings. The embeddings should be stored -- // consecutively in row major layout. Exactly one of the hashed_database and -- // float_database is expected. hashed_database can be either AH compressed or -- // 8-bit quantized. In the case of 8-bit quantization, it's casted to uint8_t. -- absl::optional<absl::Span<const uint8_t>> hashed_database; -- absl::optional<absl::Span<const float>> float_database; -- -- // The partition each of the database point belongs to. The size should be the -- // same as how many database points there are. -- absl::Span<const uint32_t> partition_assignment; -- -- // The metadata (label) for each database point. 
It should have the same size -- // as partition_assignment. -- absl::Span<const std::string> metadata; -- -- // An arbitrary user supplied string for storing custom information. -- std::string userinfo; --}; -- --// Creates a byte buffer for the index file from the artifacts. Returns errors --// when there are not exactly one database specified, or other issues with input --// such as shape mismatch, invalid partition indices etc. --absl::StatusOr<std::string> CreateIndexBuffer(const IndexedArtifacts& artifacts, -- bool compression); -- --} // namespace scann_ondevice --} // namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_FILE_MUTATOR_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc -deleted file mode 100644 -index 59b9deb8e8682..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc -+++ /dev/null -@@ -1,52 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h" -- --#include <algorithm> --#include <cstddef> --#include <cstdint> -- --#include "leveldb/env.h" // from @com_google_leveldb --#include "leveldb/slice.h" // from @com_google_leveldb --#include "leveldb/status.h" // from @com_google_leveldb -- --namespace tflite { --namespace scann_ondevice { -- --MemRandomAccessFile::MemRandomAccessFile(const char* buffer_data, -- size_t buffer_size) -- : buffer_data_(buffer_data), buffer_size_(buffer_size) {} -- --MemRandomAccessFile::~MemRandomAccessFile() {} -- --leveldb::Status MemRandomAccessFile::Read(uint64_t offset, -- size_t n, -- leveldb::Slice* result, -- char* scratch) const { -- // Sanity check. -- if (offset > buffer_size_) { -- return leveldb::Status::InvalidArgument( -- "Read offset is beyond buffer size"); -- } -- // Truncate result if the requested chunk extends beyond the buffer. -- const size_t result_size = -- std::min(n, buffer_size_ - static_cast<size_t>(offset)); -- *result = leveldb::Slice(buffer_data_ + offset, result_size); -- return leveldb::Status::OK(); --} -- --} // namespace scann_ondevice --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h -deleted file mode 100644 -index 5ca68f2e2c91e..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h -+++ /dev/null -@@ -1,61 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. 
--You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_RANDOM_ACCESS_FILE_H_ --#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_RANDOM_ACCESS_FILE_H_ -- --#include <cstddef> --#include <cstdint> -- --#include "leveldb/env.h" // from @com_google_leveldb --#include "leveldb/slice.h" // from @com_google_leveldb --#include "leveldb/status.h" // from @com_google_leveldb -- --namespace tflite { --namespace scann_ondevice { -- --// An implementation of LevelDB's RandomAccessFile [1] that wraps an in-memory --// buffer. --// --// [1]: https://github.com/google/leveldb/blob/main/include/leveldb/env.h --class MemRandomAccessFile : public leveldb::RandomAccessFile { -- public: -- // Constructor does not take ownership of the provided buffer, which must -- // outlive this object. -- MemRandomAccessFile(const char* buffer_data, size_t buffer_size); -- ~MemRandomAccessFile() override; -- -- // Override of the `Read` function. Note that `scratch` is unused in the -- // implementation. -- leveldb::Status Read(uint64_t offset, -- size_t n, -- leveldb::Slice* result, -- char* scratch) const override; -- -- // Class is movable and non-copyable. 
-- MemRandomAccessFile(MemRandomAccessFile&& rhs) = default; -- MemRandomAccessFile& operator=(MemRandomAccessFile&& rhs) = default; -- MemRandomAccessFile(const MemRandomAccessFile& rhs) = delete; -- MemRandomAccessFile& operator=(const MemRandomAccessFile& rhs) = delete; -- -- private: -- const char* buffer_data_; -- size_t buffer_size_; --}; -- --} // namespace scann_ondevice --} // namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_RANDOM_ACCESS_FILE_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h -deleted file mode 100644 -index 842e837927d4e..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h -+++ /dev/null -@@ -1,64 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_WRITABLE_FILE_H_ --#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_WRITABLE_FILE_H_ -- --#include <memory> --#include <string> -- --#include "absl/status/statusor.h" // from @com_google_absl --#include "absl/strings/cord.h" // from @com_google_absl --#include "leveldb/env.h" // from @com_google_leveldb --#include "leveldb/slice.h" // from @com_google_leveldb --#include "leveldb/status.h" // from @com_google_leveldb -- --namespace tflite { --namespace scann_ondevice { -- --// An implementation of LevelDB's WritableFile [1] that wraps an in-memory --// buffer. --// --// [1]: https://github.com/google/leveldb/blob/main/include/leveldb/env.h --class MemWritableFile : public leveldb::WritableFile { -- public: -- // Creates a MemWritableFile from a given buffer. Returns -- // InvalidArgumentError if pointer is null. -- static absl::StatusOr<std::unique_ptr<MemWritableFile>> Create( -- std::string* buffer); -- -- ~MemWritableFile() override = default; -- -- // Allow moves. Disallow copies. 
-- MemWritableFile(MemWritableFile&& rhs) = default; -- MemWritableFile& operator=(MemWritableFile&& rhs) = default; -- MemWritableFile(const MemWritableFile& rhs) = delete; -- MemWritableFile& operator=(const MemWritableFile& rhs) = delete; -- -- leveldb::Status Append(const leveldb::Slice& data) override; -- leveldb::Status Close() override; -- leveldb::Status Flush() override; -- leveldb::Status Sync() override; -- -- private: -- MemWritableFile(std::string* buffer); -- -- std::string* buffer_; --}; -- --} // namespace scann_ondevice --} // namespace tflite -- --#endif // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_WRITABLE_FILE_H_ -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc -deleted file mode 100644 -index 709564035ff1f..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc -+++ /dev/null -@@ -1,64 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#include <string> -- --#include "absl/types/optional.h" // from @com_google_absl --#include "absl/types/span.h" // from @com_google_absl --#include "pybind11/cast.h" --#include "pybind11/pybind11.h" --#include "pybind11/pytypes.h" --#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil --#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil --#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h" --#include "tensorflow_lite_support/scann_ondevice/cc/index_builder.h" -- --namespace pybind11 { -- --PYBIND11_MODULE(index_builder, m) { -- google::ImportStatusModule(); -- -- m.def( -- "create_serialized_index_file", -- [](const uint32_t embedding_dim, const std::string& serialized_config, -- const std::string userinfo, -- absl::Span<const uint32_t> partition_assignment, -- absl::Span<const std::string> metadata, bool compression, -- absl::optional<absl::Span<const uint8_t>> hashed_database, -- absl::optional<absl::Span<const float>> float_database) -- -> absl::StatusOr<bytes> { -- tflite::scann_ondevice::core::ScannOnDeviceConfig config; -- config.ParseFromString(serialized_config); -- const auto status_or_bytes = tflite::scann_ondevice::CreateIndexBuffer( -- {.config = config, -- .embedding_dim = embedding_dim, -- .hashed_database = hashed_database, -- .float_database = float_database, -- .partition_assignment = partition_assignment, -- .metadata = metadata, -- .userinfo = userinfo}, -- compression); -- if (!status_or_bytes.ok()) { -- return status_or_bytes.status(); -- } -- return bytes(status_or_bytes.value()); -- }, -- arg("embedding_dim"), arg("serialized_config"), arg("userinfo"), -- arg("partition_assignment"), arg("metadata"), arg("compression") = true, -- arg("hashed_database") = absl::nullopt, -- arg("float_database") = absl::nullopt); --} -- --} // namespace pybind11 -diff --git 
a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc -deleted file mode 100644 -index 68830a9976e41..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc -+++ /dev/null -@@ -1,363 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#include "tensorflow_lite_support/scann_ondevice/cc/index_builder.h" -- --#include <cstdint> --#include <string> -- --#include "absl/flags/flag.h" // from @com_google_absl --#include "absl/memory/memory.h" // from @com_google_absl --#include "absl/status/status.h" // from @com_google_absl --#include "absl/strings/str_format.h" // from @com_google_absl --#include "absl/strings/string_view.h" // from @com_google_absl --#include "absl/types/span.h" // from @com_google_absl --#include "leveldb/env.h" // from @com_google_leveldb --#include "leveldb/iterator.h" // from @com_google_leveldb --#include "leveldb/options.h" // from @com_google_leveldb --#include "leveldb/slice.h" // from @com_google_leveldb --#include "leveldb/status.h" // from @com_google_leveldb --#include "leveldb/table.h" // from @com_google_leveldb --#include "tensorflow_lite_support/cc/port/gmock.h" --#include "tensorflow_lite_support/cc/port/gtest.h" --#include "tensorflow_lite_support/cc/port/status_matchers.h" --#include "tensorflow_lite_support/cc/test/message_matchers.h" --#include "tensorflow_lite_support/cc/test/test_utils.h" --#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h" -- --namespace tflite { --namespace scann_ondevice { --namespace { -- --using ::testing::Bool; --using ::testing::ElementsAreArray; --using ::testing::TestWithParam; --using ::tflite::support::EqualsProto; --using ::tflite::task::ParseTextProtoOrDie; -- --absl::Status SetContents(absl::string_view file_name, -- absl::string_view content) { -- FILE* fp = fopen(file_name.data(), "w"); -- if (fp == NULL) { -- return absl::InternalError( -- absl::StrFormat("Can't open file: %s", file_name)); -- } -- -- fwrite(content.data(), sizeof(char), content.size(), fp); -- size_t write_error = ferror(fp); -- if (fclose(fp) != 0 || write_error) { -- return absl::InternalError( -- absl::StrFormat("Error while writing file: %s. 
Error message: %s", -- file_name, strerror(write_error))); -- } -- return absl::OkStatus(); --} -- --absl::StatusOr<std::string> LookupKey(leveldb::Iterator* iterator, -- absl::string_view key) { -- iterator->Seek({key.data(), key.size()}); -- if (!iterator->Valid() || iterator->key().ToString() != key || -- !iterator->status().ok()) { -- return absl::NotFoundError("Failed to lookup key"); -- } -- return iterator->value().ToString(); --} -- --constexpr size_t kDimensions = 2; --constexpr size_t kNumEmbeddings = 24; --constexpr size_t kNumPartitions = 12; -- --IndexConfig CreateExpectedConfig(IndexConfig::Type embedding_type) { -- IndexConfig config = ParseTextProtoOrDie<IndexConfig>(R"pb( -- scann_config { -- partitioner { -- leaf { dimension: 0 dimension: 0 } -- leaf { dimension: 1 dimension: 1 } -- leaf { dimension: 2 dimension: 2 } -- leaf { dimension: 3 dimension: 3 } -- leaf { dimension: 4 dimension: 4 } -- leaf { dimension: 5 dimension: 5 } -- leaf { dimension: 6 dimension: 6 } -- leaf { dimension: 7 dimension: 7 } -- leaf { dimension: 8 dimension: 8 } -- leaf { dimension: 9 dimension: 9 } -- leaf { dimension: 10 dimension: 10 } -- leaf { dimension: 11 dimension: 11 } -- } -- } -- embedding_dim: 2 -- embedding_type: UINT8 -- global_partition_offsets: 0 -- global_partition_offsets: 2 -- global_partition_offsets: 4 -- global_partition_offsets: 6 -- global_partition_offsets: 8 -- global_partition_offsets: 10 -- global_partition_offsets: 12 -- global_partition_offsets: 14 -- global_partition_offsets: 16 -- global_partition_offsets: 18 -- global_partition_offsets: 20 -- global_partition_offsets: 22 -- )pb"); -- config.set_embedding_type(embedding_type); -- return config; --} -- --class PopulateIndexFileTest : public TestWithParam<bool /*compression*/> {}; -- --TEST_P(PopulateIndexFileTest, WritesHashedDatabase) { -- const std::string db_path = -- tflite::task::JoinPath(getenv("TEST_TMPDIR"), "hashed"); -- const bool compression = GetParam(); -- -- { -- 
tflite::scann_ondevice::core::ScannOnDeviceConfig config = -- ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>( -- R"pb( -- partitioner: { -- leaf { dimension: 0 dimension: 0 } -- leaf { dimension: 1 dimension: 1 } -- leaf { dimension: 2 dimension: 2 } -- leaf { dimension: 3 dimension: 3 } -- leaf { dimension: 4 dimension: 4 } -- leaf { dimension: 5 dimension: 5 } -- leaf { dimension: 6 dimension: 6 } -- leaf { dimension: 7 dimension: 7 } -- leaf { dimension: 8 dimension: 8 } -- leaf { dimension: 9 dimension: 9 } -- leaf { dimension: 10 dimension: 10 } -- leaf { dimension: 11 dimension: 11 } -- } -- )pb"); -- std::vector<uint8_t> hashed_database; -- hashed_database.reserve(kNumEmbeddings * kDimensions); -- for (int i = 0; i < kNumEmbeddings; ++i) { -- for (int j = 0; j < kDimensions; ++j) { -- hashed_database.push_back(i); -- } -- } -- std::vector<uint32_t> partition_assignment; -- partition_assignment.reserve(kNumEmbeddings); -- for (int i = 0; i < kNumEmbeddings; ++i) { -- partition_assignment.push_back(i % kNumPartitions); -- } -- std::vector<std::string> metadata; -- metadata.reserve(kNumEmbeddings); -- for (int i = 0; i < kNumEmbeddings; ++i) { -- metadata.push_back(absl::StrFormat("%d", i)); -- } -- SUPPORT_ASSERT_OK_AND_ASSIGN( -- const std::string buffer, -- CreateIndexBuffer( -- {.config = config, -- .embedding_dim = kDimensions, -- .hashed_database = absl::Span<uint8_t>(hashed_database), -- .partition_assignment = absl::Span<uint32_t>(partition_assignment), -- .metadata = absl::Span<std::string>(metadata), -- .userinfo = "hashed_userinfo"}, -- compression)); -- SUPPORT_ASSERT_OK(SetContents(db_path, buffer)); -- } -- -- auto* env = leveldb::Env::Default(); -- leveldb::RandomAccessFile* hash_file; -- size_t hash_file_size; -- ASSERT_TRUE(env->NewRandomAccessFile(db_path, &hash_file).ok()); -- auto hashed_file_unique = absl::WrapUnique(hash_file); -- ASSERT_TRUE(env->GetFileSize(db_path, &hash_file_size).ok()); -- -- leveldb::Options 
options; -- options.compression = -- compression ? leveldb::kSnappyCompression : leveldb::kNoCompression; -- -- leveldb::Table* hashed_table; -- ASSERT_TRUE( -- leveldb::Table::Open(options, hash_file, hash_file_size, &hashed_table) -- .ok()); -- auto hashed_table_unique = absl::WrapUnique(hashed_table); -- auto hashed_table_iterator = -- absl::WrapUnique(hashed_table->NewIterator(leveldb::ReadOptions())); -- -- SUPPORT_ASSERT_OK_AND_ASSIGN( -- std::string serialized_config, -- LookupKey(hashed_table_iterator.get(), "INDEX_CONFIG")); -- IndexConfig index_config; -- EXPECT_TRUE(index_config.ParseFromString(serialized_config)); -- EXPECT_THAT(index_config, -- EqualsProto(CreateExpectedConfig(IndexConfig::UINT8))); -- -- SUPPORT_ASSERT_OK_AND_ASSIGN( -- std::string userinfo, -- LookupKey(hashed_table_iterator.get(), "USER_INFO")); -- EXPECT_EQ(userinfo, "hashed_userinfo"); -- -- // Partition assignment is based on i % kNumPartitions, so: -- // * partition 0 contains embeddings 0 and 12, -- // * partition 1 contains embeddings 1 and 13, -- // * etc -- for (int i = 0; i < kNumPartitions; ++i) { -- SUPPORT_ASSERT_OK_AND_ASSIGN( -- std::string raw_partition_hashed, -- LookupKey(hashed_table_iterator.get(), absl::StrFormat("E_%d", i))); -- std::vector<char> hashed_partition(raw_partition_hashed.begin(), -- raw_partition_hashed.end()); -- std::vector<char> expected = {static_cast<char>(i), static_cast<char>(i), -- static_cast<char>(i + kNumPartitions), -- static_cast<char>(i + kNumPartitions)}; -- EXPECT_THAT(hashed_partition, ElementsAreArray(expected)); -- } -- -- // Similarly: -- // * metadata 0 contains metadata 0, -- // * metadata 1 contains metadata 12, -- // * metadata 2 contains metadata 1, -- // * metadata 3 contains metadata 13, -- // * etc -- // Hence the `i / 2 + (i % 2 ? kNumPartitions : 0)` formula here. 
-- for (int i = 0; i < kNumEmbeddings; ++i) { -- SUPPORT_ASSERT_OK_AND_ASSIGN( -- std::string metadata, -- LookupKey(hashed_table_iterator.get(), absl::StrFormat("M_%d", i))); -- EXPECT_EQ(metadata, -- absl::StrFormat("%d", i / 2 + (i % 2 ? kNumPartitions : 0))); -- } --} -- --TEST_P(PopulateIndexFileTest, WritesFloatDatabase) { -- const std::string db_path = -- tflite::task::JoinPath(getenv("TEST_TMPDIR"), "float"); -- const bool compression = GetParam(); -- -- { -- tflite::scann_ondevice::core::ScannOnDeviceConfig config = -- ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>( -- R"pb( -- partitioner: { -- leaf { dimension: 0 dimension: 0 } -- leaf { dimension: 1 dimension: 1 } -- leaf { dimension: 2 dimension: 2 } -- leaf { dimension: 3 dimension: 3 } -- leaf { dimension: 4 dimension: 4 } -- leaf { dimension: 5 dimension: 5 } -- leaf { dimension: 6 dimension: 6 } -- leaf { dimension: 7 dimension: 7 } -- leaf { dimension: 8 dimension: 8 } -- leaf { dimension: 9 dimension: 9 } -- leaf { dimension: 10 dimension: 10 } -- leaf { dimension: 11 dimension: 11 } -- } -- )pb"); -- std::vector<float> float_database; -- float_database.reserve(kNumEmbeddings * kDimensions); -- for (int i = 0; i < kNumEmbeddings; ++i) { -- for (int j = 0; j < kDimensions; ++j) { -- float_database.push_back(i); -- } -- } -- std::vector<uint32_t> partition_assignment; -- partition_assignment.reserve(kNumEmbeddings); -- for (int i = 0; i < kNumEmbeddings; ++i) { -- partition_assignment.push_back(i % kNumPartitions); -- } -- std::vector<std::string> metadata; -- metadata.reserve(kNumEmbeddings); -- for (int i = 0; i < kNumEmbeddings; ++i) { -- metadata.push_back(absl::StrFormat("%d", i)); -- } -- SUPPORT_ASSERT_OK_AND_ASSIGN( -- const std::string buffer, -- CreateIndexBuffer( -- {.config = config, -- .embedding_dim = kDimensions, -- .float_database = absl::Span<float>(float_database), -- .partition_assignment = absl::Span<uint32_t>(partition_assignment), -- .metadata = 
absl::Span<std::string>(metadata), -- .userinfo = "float_userinfo"}, -- compression)); -- SUPPORT_ASSERT_OK(SetContents(db_path, buffer)); -- } -- -- auto* env = leveldb::Env::Default(); -- leveldb::RandomAccessFile* float_file; -- size_t float_file_size; -- ASSERT_TRUE(env->NewRandomAccessFile(db_path, &float_file).ok()); -- auto float_file_unique = absl::WrapUnique(float_file); -- ASSERT_TRUE(env->GetFileSize(db_path, &float_file_size).ok()); -- -- leveldb::Options options; -- options.compression = -- compression ? leveldb::kSnappyCompression : leveldb::kNoCompression; -- -- leveldb::Table* float_table; -- ASSERT_TRUE( -- leveldb::Table::Open(options, float_file, float_file_size, &float_table) -- .ok()); -- auto float_table_unique = absl::WrapUnique(float_table); -- auto float_table_iterator = -- absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions())); -- -- SUPPORT_ASSERT_OK_AND_ASSIGN( -- std::string serialized_config, -- LookupKey(float_table_iterator.get(), "INDEX_CONFIG")); -- IndexConfig index_config; -- EXPECT_TRUE(index_config.ParseFromString(serialized_config)); -- EXPECT_THAT(index_config, -- EqualsProto(CreateExpectedConfig(IndexConfig::FLOAT))); -- -- SUPPORT_ASSERT_OK_AND_ASSIGN( -- std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO")); -- EXPECT_EQ(userinfo, "float_userinfo"); -- -- // Partition assignment is based on i % kNumPartitions, so: -- // * partition 0 contains embeddings 0 and 12, -- // * partition 1 contains embeddings 1 and 13, -- // * etc -- for (int i = 0; i < kNumPartitions; ++i) { -- SUPPORT_ASSERT_OK_AND_ASSIGN( -- std::string raw_partition_float, -- LookupKey(float_table_iterator.get(), absl::StrFormat("E_%d", i))); -- const float* raw_partition_float_ptr = -- reinterpret_cast<const float*>(raw_partition_float.data()); -- std::vector<float> float_partition( -- raw_partition_float_ptr, -- raw_partition_float_ptr + raw_partition_float.size() / sizeof(float)); -- std::vector<float> expected = 
{static_cast<float>(i), static_cast<float>(i), -- static_cast<float>(i + kNumPartitions), -- static_cast<float>(i + kNumPartitions)}; -- EXPECT_THAT(float_partition, ElementsAreArray(expected)); -- } -- -- // Similarly: -- // * metadata 0 contains metadata 0, -- // * metadata 1 contains metadata 12, -- // * metadata 2 contains metadata 1, -- // * metadata 3 contains metadata 13, -- // * etc -- // Hence the `i / 2 + (i % 2 ? kNumPartitions : 0)` formula here. -- for (int i = 0; i < kNumEmbeddings; ++i) { -- SUPPORT_ASSERT_OK_AND_ASSIGN( -- std::string metadata, -- LookupKey(float_table_iterator.get(), absl::StrFormat("M_%d", i))); -- EXPECT_EQ(metadata, -- absl::StrFormat("%d", i / 2 + (i % 2 ? kNumPartitions : 0))); -- } --} -- --INSTANTIATE_TEST_SUITE_P(PopulateIndexFileTest, PopulateIndexFileTest, Bool()); -- --} // namespace --} // namespace scann_ondevice --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_random_access_file_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_random_access_file_test.cc -deleted file mode 100644 -index 2d1efb10e2b6c..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_random_access_file_test.cc -+++ /dev/null -@@ -1,64 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. 
--==============================================================================*/ -- --#include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h" -- --#include "leveldb/slice.h" // from @com_google_leveldb --#include "tensorflow/lite/core/shims/cc/shims_test_util.h" --#include "tensorflow_lite_support/cc/port/gmock.h" --#include "tensorflow_lite_support/cc/port/gtest.h" -- --namespace tflite { --namespace scann_ondevice { --namespace { -- --constexpr char kBufferData[] = "abcdef"; --constexpr size_t kBufferSize = 6; -- --class MemRandomAccessFileTest : public tflite_shims::testing::Test { -- public: -- MemRandomAccessFileTest() : file_(kBufferData, kBufferSize) {} -- -- protected: -- MemRandomAccessFile file_; -- leveldb::Slice result_; --}; -- --TEST_F(MemRandomAccessFileTest, ReadFailsWithOutOfBoundsOffset) { -- EXPECT_TRUE(file_.Read(/*offset=*/7, /*n=*/1, &result_, /*scratch=*/nullptr) -- .IsInvalidArgument()); --} -- --TEST_F(MemRandomAccessFileTest, ReadSucceedsWithoutTruncation) { -- EXPECT_TRUE( -- file_.Read(/*offset=*/1, /*n=*/5, &result_, /*scratch=*/nullptr).ok()); -- EXPECT_EQ("bcdef", result_.ToString()); --} -- --TEST_F(MemRandomAccessFileTest, ReadSucceedsWithTruncation) { -- EXPECT_TRUE( -- file_.Read(/*offset=*/1, /*n=*/6, &result_, /*scratch=*/nullptr).ok()); -- EXPECT_EQ("bcdef", result_.ToString()); --} -- --TEST_F(MemRandomAccessFileTest, ReadSucceedsWithZeroLength) { -- EXPECT_TRUE( -- file_.Read(/*offset=*/1, /*n=*/0, &result_, /*scratch=*/nullptr).ok()); -- EXPECT_EQ("", result_.ToString()); --} -- --} // namespace --} // namespace scann_ondevice --} // namespace tflite -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc -deleted file mode 100644 -index 1b6906bf49bc2..0000000000000 ---- 
a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc -+++ /dev/null -@@ -1,76 +0,0 @@ --/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. -- --Licensed under the Apache License, Version 2.0 (the "License"); --you may not use this file except in compliance with the License. --You may obtain a copy of the License at -- -- http://www.apache.org/licenses/LICENSE-2.0 -- --Unless required by applicable law or agreed to in writing, software --distributed under the License is distributed on an "AS IS" BASIS, --WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --See the License for the specific language governing permissions and --limitations under the License. --==============================================================================*/ -- --#include <vector> -- --#include "absl/memory/memory.h" // from @com_google_absl --#include "absl/status/status.h" // from @com_google_absl --#include "absl/status/statusor.h" // from @com_google_absl --#include "absl/strings/str_format.h" // from @com_google_absl --#include "leveldb/env.h" // from @com_google_leveldb --#include "leveldb/options.h" // from @com_google_leveldb --#include "leveldb/table.h" // from @com_google_leveldb --#include "pybind11/cast.h" --#include "pybind11/pybind11.h" --#include "pybind11/pytypes.h" --#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil --#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil -- --namespace pybind11 { -- --PYBIND11_MODULE(leveldb_testing_utils, m) { -- google::ImportStatusModule(); -- -- m.def( -- "leveldb_table_to_pair_list", -- [](const std::string fname, bool compressed) -- -> absl::StatusOr<std::vector<std::pair<bytes, bytes>>> { -- auto* env = leveldb::Env::Default(); -- leveldb::RandomAccessFile* file; -- if (!env->NewRandomAccessFile(fname, &file).ok()) { -- return absl::InternalError(absl::StrFormat( -- "Failed to create 
RandomAccessFile at %s", fname)); -- } -- auto unique_file = absl::WrapUnique(file); -- size_t file_size; -- if (!env->GetFileSize(fname, &file_size).ok()) { -- return absl::InternalError( -- absl::StrFormat("Failed to get file size at %s", fname)); -- } -- leveldb::Options options; -- options.compression = -- compressed ? leveldb::kSnappyCompression : leveldb::kNoCompression; -- -- leveldb::Table* table; -- if (!leveldb::Table::Open(options, file, file_size, &table).ok()) { -- return absl::InternalError("Failed to open table"); -- } -- auto unique_table = absl::WrapUnique(table); -- auto table_iterator = -- absl::WrapUnique(table->NewIterator(leveldb::ReadOptions())); -- table_iterator->SeekToFirst(); -- -- std::vector<std::pair<bytes, bytes>> result; -- for (; table_iterator->Valid(); table_iterator->Next()) { -- result.push_back( -- std::make_pair(bytes(table_iterator->key().ToString()), -- bytes(table_iterator->value().ToString()))); -- } -- return result; -- }, -- arg("buffer"), arg("compressed")); --} -- --} // namespace pybind11 -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat -deleted file mode 100644 -index 324a1953de706..0000000000000 ---- a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat -+++ /dev/null -@@ -1,29 +0,0 @@ --:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. --:: --:: Licensed under the Apache License, Version 2.0 (the "License"); --:: you may not use this file except in compliance with the License. --:: You may obtain a copy of the License at --:: --:: http://www.apache.org/licenses/LICENSE-2.0 --:: --:: Unless required by applicable law or agreed to in writing, software --:: distributed under the License is distributed on an "AS IS" BASIS, --:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
--:: See the License for the specific language governing permissions and --:: limitations under the License. --:: ============================================================================= -- --:: This script is shamefully borrowed from: --:: //third_party/tensorflow/tools/ci_build/release/common_win.bat.oss -- --echo on -- --@REM --@REM Setup Bazel --@REM --:: Download Bazel from github and make sure its found in PATH. --SET BAZEL_VERSION=4.2.2 --md C:\tools\bazel\ --wget -q https://github.com/bazelbuild/bazel/releases/download/%BAZEL_VERSION%/bazel-%BAZEL_VERSION%-windows-x86_64.exe -O C:/tools/bazel/bazel.exe --SET PATH=C:\tools\bazel;%PATH% --bazel version --- -2.35.1.1178.g4f1659d476-goog -
diff --git a/third_party/tflite_support/patches/0006-run-clang-format.patch b/third_party/tflite_support/patches/0008-run-clang-format.patch similarity index 91% rename from third_party/tflite_support/patches/0006-run-clang-format.patch rename to third_party/tflite_support/patches/0008-run-clang-format.patch index d623fec..fbb051c2 100644 --- a/third_party/tflite_support/patches/0006-run-clang-format.patch +++ b/third_party/tflite_support/patches/0008-run-clang-format.patch
@@ -1,7 +1,7 @@ -From 7740a3260ec7b48f06fd658a22d5278ded7c323e Mon Sep 17 00:00:00 2001 +From 536821c33d55b5d714910c015008d2cebd7dfef5 Mon Sep 17 00:00:00 2001 From: Robert Ogden <robertogden@chromium.org> -Date: Wed, 13 Apr 2022 11:00:04 -0700 -Subject: [PATCH 6/9] run clang format +Date: Wed, 25 May 2022 11:03:46 -0700 +Subject: [PATCH 8/9] run clang format --- .../configuration/edgetpu_coral_plugin.cc | 20 +- @@ -10,6 +10,10 @@ .../src/tensorflow_lite_support/c/common.h | 4 +- .../tensorflow_lite_support/c/common_utils.cc | 11 +- .../tensorflow_lite_support/c/common_utils.h | 3 +- + .../c/task/audio/audio_classifier.cc | 12 +- + .../c/task/audio/audio_classifier.h | 12 +- + .../c/task/audio/core/audio_buffer.h | 4 +- + .../c/task/processor/classification_result.cc | 2 +- .../c/task/text/bert_nl_classifier.cc | 6 +- .../c/task/text/bert_nl_classifier.h | 6 +- .../c/task/text/bert_question_answerer.cc | 3 +- @@ -22,6 +26,7 @@ .../c/task/vision/image_segmenter.h | 6 +- .../c/task/vision/object_detector.cc | 6 +- .../c/task/vision/object_detector.h | 6 +- + .../test/task/audio/audio_classifier_test.cc | 32 +- .../test/task/vision/image_classifier_test.cc | 84 +- .../test/task/vision/image_segmenter_test.cc | 62 +- .../test/task/vision/object_detector_test.cc | 90 +- @@ -43,8 +48,8 @@ .../cc/task/core/base_task_api.h | 2 +- .../cc/task/core/classification_head.h | 2 +- .../cc/task/core/error_reporter.cc | 8 +- - .../cc/task/core/external_file_handler.cc | 4 +- - .../cc/task/core/external_file_handler.h | 2 +- + .../cc/task/core/external_file_handler.cc | 7 +- + .../cc/task/core/external_file_handler.h | 3 +- .../cc/task/core/label_map_item.cc | 5 +- .../cc/task/core/label_map_item.h | 7 +- .../cc/task/core/score_calibration.cc | 8 +- @@ -60,7 +65,7 @@ .../cc/task/processor/processor.h | 5 +- .../cc/task/processor/regex_preprocessor.cc | 3 +- .../cc/task/processor/regex_preprocessor.h | 3 +- - .../cc/task/processor/search_postprocessor.cc | 34 +- + 
.../cc/task/processor/search_postprocessor.cc | 40 +- .../cc/task/processor/search_postprocessor.h | 37 +- .../cc/task/text/bert_clu_annotator.cc | 4 +- .../cc/task/text/bert_nl_classifier.cc | 3 +- @@ -110,7 +115,9 @@ .../task/vision/utils/frame_buffer_utils.cc | 50 +- .../cc/task/vision/utils/frame_buffer_utils.h | 40 +- .../utils/frame_buffer_utils_interface.h | 11 +- - .../vision/utils/libyuv_frame_buffer_utils.cc | 79 +- + .../cc/task/vision/utils/image_utils.cc | 12 +- + .../cc/task/vision/utils/image_utils.h | 2 +- + .../vision/utils/libyuv_frame_buffer_utils.cc | 81 +- .../vision/utils/libyuv_frame_buffer_utils.h | 9 +- .../cc/task/vision/utils/score_calibration.cc | 8 +- .../cc/task/vision/utils/score_calibration.h | 11 +- @@ -121,10 +128,12 @@ .../test/task/text/clu_lib/bert_utils_test.cc | 32 +- .../task/text/clu_lib/intent_repr_test.cc | 2 +- .../text/nlclassifier/nl_classifier_test.cc | 83 +- - .../cc/test/task/text/text_embedder_test.cc | 20 +- + .../cc/test/task/text/text_embedder_test.cc | 26 +- + .../cc/test/task/text/text_searcher_test.cc | 18 +- .../universal_sentence_encoder_qa_test.cc | 16 +- .../test/task/vision/image_classifier_test.cc | 158 +- .../test/task/vision/image_embedder_test.cc | 95 +- + .../test/task/vision/image_searcher_test.cc | 62 +- .../test/task/vision/image_segmenter_test.cc | 117 +- .../test/task/vision/object_detector_test.cc | 157 +- .../cc/test/test_utils.cc | 18 +- @@ -141,7 +150,7 @@ .../cc/utils/common_utils.cc | 3 +- .../cc/utils/common_utils.h | 3 +- .../cc/utils/jni_utils.cc | 7 +- - .../cc/utils/jni_utils.h | 8 +- + .../cc/utils/jni_utils.h | 9 +- .../codegen/android_java_generator.cc | 37 +- .../codegen/android_java_generator.h | 5 +- .../codegen/code_generator.cc | 3 +- @@ -176,50 +185,57 @@ .../text/desktop/bert_nl_classifier_demo.cc | 14 +- .../desktop/bert_question_answerer_demo.cc | 18 +- .../task/text/desktop/nl_classifier_demo.cc | 14 +- + .../task/text/desktop/text_embedder_demo.cc | 26 +- + 
.../task/text/desktop/text_searcher_demo.cc | 30 +- .../universal_sentence_encoder_qa_demo.cc | 17 +- .../vision/desktop/image_classifier_demo.cc | 34 +- .../vision/desktop/image_embedder_demo.cc | 30 +- .../vision/desktop/image_searcher_demo.cc | 30 +- .../vision/desktop/image_segmenter_demo.cc | 24 +- .../vision/desktop/object_detector_demo.cc | 40 +- - .../task/vision/desktop/utils/image_utils.cc | 12 +- - .../task/vision/desktop/utils/image_utils.h | 2 +- .../ios/sources/TFLCommon.h | 11 +- .../ios/sources/TFLCommonUtils.h | 32 +- .../ios/sources/TFLCommonUtils.m | 19 +- + .../task/audio/core/sources/TFLFloatBuffer.h | 18 +- + .../task/audio/core/sources/TFLFloatBuffer.m | 4 +- + .../task/audio/core/sources/TFLRingBuffer.h | 32 +- + .../task/audio/core/sources/TFLRingBuffer.m | 49 +- .../core/sources/TFLBaseOptions+Helpers.h | 2 +- .../ios/task/core/sources/TFLBaseOptions.h | 32 +- .../processor/sources/TFLCategory+Helpers.h | 2 +- .../processor/sources/TFLCategory+Helpers.m | 7 +- - .../ios/task/processor/sources/TFLCategory.h | 16 +- + .../ios/task/processor/sources/TFLCategory.h | 22 +- .../ios/task/processor/sources/TFLCategory.m | 4 +- .../TFLClassificationOptions+Helpers.h | 6 +- - .../TFLClassificationOptions+Helpers.m | 31 +- + .../TFLClassificationOptions+Helpers.m | 33 +- .../sources/TFLClassificationOptions.h | 9 +- .../sources/TFLClassificationResult+Helpers.h | 17 +- - .../sources/TFLClassificationResult+Helpers.m | 9 +- - .../sources/TFLClassificationResult.h | 35 +- - .../sources/TFLClassificationResult.m | 7 +- - .../sources/TFLDetectionResult+Helpers.h | 12 +- - .../processor/sources/TFLDetectionResult.h | 15 +- + .../sources/TFLClassificationResult+Helpers.m | 22 +- + .../sources/TFLClassificationResult.h | 79 +- + .../sources/TFLClassificationResult.m | 12 +- + .../sources/TFLDetectionResult+Helpers.h | 11 +- + .../sources/TFLDetectionResult+Helpers.m | 15 +- + .../processor/sources/TFLDetectionResult.h | 35 +- + 
.../processor/sources/TFLDetectionResult.m | 4 +- .../sources/TFLSegmentationResult+Helpers.h | 4 +- - .../sources/TFLSegmentationResult+Helpers.m | 33 +- - .../processor/sources/TFLSegmentationResult.h | 21 +- - .../processor/sources/TFLSegmentationResult.m | 16 +- + .../sources/TFLSegmentationResult+Helpers.m | 44 +- + .../processor/sources/TFLSegmentationResult.h | 65 +- + .../processor/sources/TFLSegmentationResult.m | 45 +- .../Sources/TFLBertNLClassifier.h | 21 +- .../nlclassifier/Sources/TFLNLClassifier.h | 47 +- .../text/qa/Sources/TFLBertQuestionAnswerer.h | 4 +- - .../task/vision/sources/TFLImageClassifier.h | 37 +- - .../task/vision/sources/TFLImageClassifier.m | 20 +- - .../task/vision/sources/TFLImageSegmenter.h | 49 +- - .../task/vision/sources/TFLImageSegmenter.m | 34 +- - .../task/vision/sources/TFLObjectDetector.h | 36 +- - .../task/vision/sources/TFLObjectDetector.m | 13 +- + .../task/vision/sources/TFLImageClassifier.h | 90 +- + .../task/vision/sources/TFLImageClassifier.m | 58 +- + .../task/vision/sources/TFLImageSegmenter.h | 62 +- + .../task/vision/sources/TFLImageSegmenter.m | 49 +- + .../task/vision/sources/TFLObjectDetector.h | 64 +- + .../task/vision/sources/TFLObjectDetector.m | 54 +- .../vision/utils/sources/GMLImage+Utils.h | 8 +- - .../vision/utils/sources/GMLImage+Utils.m | 231 ++- + .../vision/utils/sources/GMLImage+Utils.m | 225 +- + .../test/task/audio/core/TFLRingBufferTests.m | 171 +- .../TFLImageClassifierTests.m | 28 +- .../image_segmenter/TFLImageSegmenterTests.m | 64 +- - .../object_detector/TFLObjectDetectorTests.m | 1 - + .../object_detector/TFLObjectDetectorTests.m | 36 +- .../tokenizers/Sources/TFLBertTokenizer.h | 6 +- .../Sources/TFLSentencepieceTokenizer.h | 2 +- .../text/tokenizers/Sources/TFLTokenizer.h | 4 +- @@ -269,17 +285,22 @@ .../lite/task/core/BaseTaskApi.java | 122 +- .../lite/task/core/ComputeSettings.java | 48 +- .../lite/task/core/TaskJniUtils.java | 275 ++- + .../core/annotations/UsedByReflection.java | 2 
+- .../core/vision/ImageProcessingOptions.java | 125 +- + .../lite/task/processor/NearestNeighbor.java | 53 +- + .../lite/task/processor/SearcherOptions.java | 114 +- .../text/nlclassifier/BertNLClassifier.java | 391 ++-- .../task/text/nlclassifier/NLClassifier.java | 568 ++--- .../task/text/qa/BertQuestionAnswerer.java | 394 ++-- .../lite/task/text/qa/QaAnswer.java | 60 +- .../lite/task/text/qa/QuestionAnswerer.java | 19 +- + .../lite/task/text/searcher/TextSearcher.java | 375 ++-- .../vision/classifier/Classifications.java | 25 +- .../vision/classifier/ImageClassifier.java | 882 ++++---- .../task/vision/core/BaseVisionTaskApi.java | 349 ++-- .../lite/task/vision/detector/Detection.java | 26 +- .../task/vision/detector/ObjectDetector.java | 873 ++++---- + .../task/vision/searcher/ImageSearcher.java | 605 +++--- .../task/vision/segmenter/ColoredLabel.java | 112 +- .../task/vision/segmenter/ImageSegmenter.java | 752 ++++--- .../task/vision/segmenter/OutputType.java | 202 +- @@ -321,19 +342,23 @@ .../bert/bert_nl_classifier_jni.cc | 23 +- .../text/nlclassifier/nl_classifier_jni.cc | 21 +- .../text/qa/bert_question_answerer_jni.cc | 24 +- + .../task/text/searcher/text_searcher_jni.cc | 36 +- .../vision/classifier/image_classifier_jni.cc | 27 +- .../vision/core/base_vision_task_api_jni.cc | 40 +- .../vision/detector/object_detector_jni.cc | 27 +- .../java/src/native/task/vision/jni_utils.cc | 30 +- .../java/src/native/task/vision/jni_utils.h | 28 +- + .../vision/searcher/image_searcher_jni.cc | 36 +- .../vision/segmenter/image_segmenter_jni.cc | 32 +- - .../metadata/cc/metadata_extractor.cc | 21 +- + .../metadata/cc/metadata_extractor.cc | 20 +- .../metadata/cc/metadata_extractor.h | 4 +- .../metadata/cc/metadata_populator.cc | 2 +- .../metadata/cc/metadata_populator.h | 7 +- .../metadata/cc/metadata_version.cc | 33 +- - .../metadata/cc/utils/zip_mem_file.cc | 20 +- - .../metadata/cc/utils/zip_mem_file.h | 8 +- + .../cc/utils/zip_readonly_mem_file.cc | 13 +- + 
.../metadata/cc/utils/zip_readonly_mem_file.h | 4 +- + .../cc/utils/zip_writable_mem_file.cc | 17 +- + .../metadata/cc/utils/zip_writable_mem_file.h | 4 +- .../flatbuffers_lib/flatbuffers_lib.cc | 2 +- .../support/metadata/BoundedInputStream.java | 138 +- .../support/metadata/ByteBufferChannel.java | 188 +- @@ -370,9 +395,10 @@ .../odml/image/MediaMlImageBuilderTest.java | 109 +- .../android/odml/image/TestImageCreator.java | 211 +- .../core/pybinds/_pywrap_audio_buffer.cc | 17 +- + .../audio/pybinds/_pywrap_audio_classifier.cc | 1 - .../audio/pybinds/_pywrap_audio_embedder.cc | 22 +- .../task/vision/core/pybinds/image_utils.cc | 4 +- - .../pybinds/_pywrap_image_classifier.cc | 18 +- + .../pybinds/_pywrap_image_classifier.cc | 16 +- .../vision/pybinds/_pywrap_image_segmenter.cc | 12 +- .../vision/pybinds/_pywrap_object_detector.cc | 13 +- .../scann_ondevice/cc/core/index_table_sum.h | 41 +- @@ -383,7 +409,6 @@ .../scann_ondevice/cc/core/partitioner.h | 5 +- .../scann_ondevice/cc/core/searcher.h | 29 +- .../scann_ondevice/cc/core/searcher_test.cc | 9 +- - .../cc/core/serialized_searcher.proto | 16 +- .../cc/core/top_n_amortized_constant.h | 12 +- .../scann_ondevice/cc/index.cc | 23 +- .../scann_ondevice/cc/index.h | 13 +- @@ -393,13 +418,13 @@ .../cc/mem_random_access_file.h | 8 +- .../scann_ondevice/cc/mem_writable_file.h | 8 +- .../cc/python/index_builder_py_wrapper.cc | 6 +- - .../cc/test/index_builder_test.cc | 107 +- + .../cc/test/index_builder_test.cc | 143 +- .../scann_ondevice/cc/test/index_test.cc | 33 +- .../cc/test/mem_writable_file_test.cc | 2 +- .../leveldb_testing_utils_py_wrapper.cc | 14 +- .../src/third_party/fft2d/fft.h | 12 +- .../src/third_party/fft2d/fft2d.h | 12 +- - 395 files changed, 18104 insertions(+), 17504 deletions(-) + 420 files changed, 19248 insertions(+), 18509 deletions(-) diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc 
b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc index 9f27f3baae82f..6a16d12856258 100644 @@ -444,7 +469,7 @@ } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc -index 83cb6f24b1277..cc183a65a9e5f 100644 +index a02635b9f3578..6ac4e5c734567 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc @@ -43,7 +43,8 @@ using ::tflite::task::vision::ImageDataFree; @@ -536,6 +561,117 @@ // Creates a TfLiteSupportError from absl::Status and passes it back as a // parameter which is a pointer to the error pointer. +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc +index 89fba26b9b72f..3f1781a0a7db8 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc +@@ -109,7 +109,8 @@ TfLiteAudioClassifierOptions TfLiteAudioClassifierOptionsCreate(void) { + } + + TfLiteAudioClassifier* TfLiteAudioClassifierFromOptions( +- const TfLiteAudioClassifierOptions* options, TfLiteSupportError** error) { ++ const TfLiteAudioClassifierOptions* options, ++ TfLiteSupportError** error) { + StatusOr<AudioClassifierOptionsCpp> cpp_option_status = + CreateAudioClassifierCppOptionsFromCOptions(options); + +@@ -181,7 +182,8 @@ TfLiteClassificationResult* GetClassificationResultCStruct( + + TfLiteClassificationResult* TfLiteAudioClassifierClassify( + const TfLiteAudioClassifier* classifier, +- const 
TfLiteAudioBuffer* audio_buffer, TfLiteSupportError** error) { ++ const TfLiteAudioBuffer* audio_buffer, ++ TfLiteSupportError** error) { + if (classifier == nullptr) { + tflite::support::CreateTfLiteSupportError( + kInvalidArgumentError, "Expected non null audio classifier.", error); +@@ -211,7 +213,8 @@ TfLiteClassificationResult* TfLiteAudioClassifierClassify( + } + + int TfLiteAudioClassifierGetRequiredInputBufferSize( +- TfLiteAudioClassifier* classifier, TfLiteSupportError** error) { ++ TfLiteAudioClassifier* classifier, ++ TfLiteSupportError** error) { + if (classifier == nullptr) { + tflite::support::CreateTfLiteSupportError( + kInvalidArgumentError, "Expected non null audio classifier.", error); +@@ -226,7 +229,8 @@ void TfLiteAudioClassifierDelete(TfLiteAudioClassifier* classifier) { + } + + TfLiteAudioFormat* TfLiteAudioClassifierGetRequiredAudioFormat( +- TfLiteAudioClassifier* classifier, TfLiteSupportError** error) { ++ TfLiteAudioClassifier* classifier, ++ TfLiteSupportError** error) { + if (classifier == nullptr) { + tflite::support::CreateTfLiteSupportError( + kInvalidArgumentError, "Expected non null audio classifier.", error); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h +index e83295963378c..6af9b27944744 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h +@@ -157,7 +157,8 @@ TfLiteAudioClassifierOptions TfLiteAudioClassifierOptionsCreate(void); + // TfLiteSupportErrorDelete(error) + // + TfLiteAudioClassifier* TfLiteAudioClassifierFromOptions( +- const TfLiteAudioClassifierOptions* options, TfLiteSupportError** error); ++ const TfLiteAudioClassifierOptions* options, ++ TfLiteSupportError** error); + + // Invokes the encapsulated TFLite model and classifies the 
frame_buffer. + // Returns a pointer to the created classification result in case of success or +@@ -185,15 +186,18 @@ TfLiteAudioClassifier* TfLiteAudioClassifierFromOptions( + // + TfLiteClassificationResult* TfLiteAudioClassifierClassify( + const TfLiteAudioClassifier* classifier, +- const TfLiteAudioBuffer* audio_buffer, TfLiteSupportError** error); ++ const TfLiteAudioBuffer* audio_buffer, ++ TfLiteSupportError** error); + + // Returns the input buffer size required by the audio classifier. + int TfLiteAudioClassifierGetRequiredInputBufferSize( +- TfLiteAudioClassifier* classifier, TfLiteSupportError** error); ++ TfLiteAudioClassifier* classifier, ++ TfLiteSupportError** error); + + // Returns the audio format required by the audio classifier. + TfLiteAudioFormat* TfLiteAudioClassifierGetRequiredAudioFormat( +- TfLiteAudioClassifier* classifier, TfLiteSupportError** error); ++ TfLiteAudioClassifier* classifier, ++ TfLiteSupportError** error); + + // Disposes off the audio classifier. 
+ void TfLiteAudioClassifierDelete(TfLiteAudioClassifier* classifier); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h +index 2ec7571036d29..471f02fdf2132 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h +@@ -45,11 +45,11 @@ typedef struct TfLiteAudioBuffer { + int size; + } TfLiteAudioBuffer; + +-void TfLiteAudioBufferDelete(TfLiteAudioBuffer *buffer); ++void TfLiteAudioBufferDelete(TfLiteAudioBuffer* buffer); + + void TfLiteAudioBufferDeleteData(const TfLiteAudioBuffer audio_buffer); + +-void TfLiteAudioFormatDelete(TfLiteAudioFormat *format); ++void TfLiteAudioFormatDelete(TfLiteAudioFormat* format); + + #ifdef __cplusplus + } // extern "C" +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc +index 646e2c237c2f8..b7d7fab827694 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc +@@ -27,7 +27,7 @@ void TfLiteClassificationResultDelete( + for (int head = 0; head < classification_result->size; ++head) { + TfLiteClassifications classifications = + classification_result->classifications[head]; +- free(classifications.head_name); ++ free(classifications.head_name); + for (int rank = 0; rank < classifications.size; ++rank) { + TfLiteCategoryDelete(&(classifications.categories[rank])); + } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc 
b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc index 26888a832fc34..52907f4fe7d35 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc @@ -641,7 +777,7 @@ // Invokes the encapsulated TFLite model and classifies the input text. Categories* TfLiteNLClassifierClassify(const TfLiteNLClassifier* classifier, diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc -index ecf519c61c368..8d9aa850100c5 100644 +index 52e215116b51e..183468a6855aa 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc @@ -110,7 +110,8 @@ TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate(void) { @@ -654,7 +790,7 @@ StatusOr<ImageClassifierOptionsCpp> cpp_option_status = CreateImageClassifierCppOptionsFromCOptions(options); -@@ -177,7 +178,8 @@ TfLiteClassificationResult* GetClassificationResultCStruct( +@@ -178,7 +179,8 @@ TfLiteClassificationResult* GetClassificationResultCStruct( TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi( const TfLiteImageClassifier* classifier, @@ -664,7 +800,7 @@ TfLiteSupportError** error) { if (classifier == nullptr) { tflite::support::CreateTfLiteSupportError( -@@ -220,7 +222,8 @@ TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi( +@@ -221,7 +223,8 @@ TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi( TfLiteClassificationResult* TfLiteImageClassifierClassify( const TfLiteImageClassifier* classifier, @@ -804,8 +940,93 @@ TfLiteSupportError** error); // Disposes off the object detector. 
+diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc +index 17b2a4ccede29..126784cf6c755 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc +@@ -45,9 +45,10 @@ constexpr char kYamNetAudioClassifierWithMetadata[] = + "yamnet_audio_classifier_with_metadata.tflite"; + + StatusOr<TfLiteAudioBuffer> LoadAudioBufferFromFileNamed( +- const std::string wav_file, int buffer_size) { +- std::string contents = ReadFile( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, wav_file)); ++ const std::string wav_file, ++ int buffer_size) { ++ std::string contents = ++ ReadFile(JoinPath("./" /*test src dir*/, kTestDataDirectory, wav_file)); + + uint32_t decoded_sample_count; + uint16_t decoded_channel_count; +@@ -90,7 +91,8 @@ void Verify(TfLiteClassificationResult* classification_result, + } + + void Verify(TfLiteClassifications& classifications, +- int expected_categories_size, int expected_head_index, ++ int expected_categories_size, ++ int expected_head_index, + char const* expected_head_name) { + EXPECT_EQ(classifications.size, expected_categories_size); + EXPECT_EQ(classifications.head_index, expected_head_index); +@@ -101,8 +103,10 @@ void Verify(TfLiteClassifications& classifications, + EXPECT_NE(classifications.categories, nullptr); + } + +-void Verify(TfLiteCategory& category, int expected_index, +- char const* expected_label, float expected_score) { ++void Verify(TfLiteCategory& category, ++ int expected_index, ++ char const* expected_label, ++ float expected_score) { + const float kPrecision = 1e-6; + EXPECT_EQ(category.index, expected_index); + EXPECT_NE(category.label, nullptr); +@@ -115,7 +119,8 @@ void Verify(TfLiteCategory& category, int 
expected_index, + EXPECT_NEAR(category.score, expected_score, kPrecision); + } + +-void Verify(TfLiteSupportError* error, TfLiteSupportErrorCode error_code, ++void Verify(TfLiteSupportError* error, ++ TfLiteSupportErrorCode error_code, + char const* message) { + ASSERT_NE(error, nullptr); + EXPECT_EQ(error->code, kInvalidArgumentError); +@@ -133,7 +138,8 @@ TEST_F(AudioClassifierFromOptionsTest, FailsWithMissingModelPathAndError) { + TfLiteAudioClassifierFromOptions(&options, &error); + + EXPECT_EQ(audio_classifier, nullptr); +- if (audio_classifier) TfLiteAudioClassifierDelete(audio_classifier); ++ if (audio_classifier) ++ TfLiteAudioClassifierDelete(audio_classifier); + + Verify(error, kInvalidArgumentError, + "INVALID_ARGUMENT: Missing mandatory `model_file` field in " +@@ -143,9 +149,8 @@ TEST_F(AudioClassifierFromOptionsTest, FailsWithMissingModelPathAndError) { + } + + TEST_F(AudioClassifierFromOptionsTest, SucceedsWithModelPath) { +- std::string model_path = +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kYamNetAudioClassifierWithMetadata); ++ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, ++ kYamNetAudioClassifierWithMetadata); + TfLiteAudioClassifierOptions options = TfLiteAudioClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); + TfLiteAudioClassifier* audio_classifier = +@@ -158,9 +163,8 @@ TEST_F(AudioClassifierFromOptionsTest, SucceedsWithModelPath) { + class AudioClassifierClassifyTest : public tflite_shims::testing::Test { + protected: + void SetUp() override { +- std::string model_path = +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kYamNetAudioClassifierWithMetadata); ++ std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory, ++ kYamNetAudioClassifierWithMetadata); + + TfLiteAudioClassifierOptions options = TfLiteAudioClassifierOptionsCreate(); + options.base_options.model_file.file_path = model_path.data(); diff --git 
a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc -index 688af14580ab3..b398b7adafe5c 100644 +index 0a59344f4394c..cce2fa63fad17 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc @@ -44,8 +44,8 @@ constexpr char kMobileNetQuantizedWithMetadata[] = @@ -1044,7 +1265,7 @@ ASSERT_NE(classification_result, nullptr); EXPECT_GE(classification_result->size, 1); diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc -index f5e37f94e749a..81ade94585caf 100644 +index d4c8106b2729d..c03c15d6fe6b7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc @@ -46,8 +46,8 @@ constexpr char kTestDataDirectory[] = @@ -1209,7 +1430,7 @@ int inconsistent_pixels = 0; int num_pixels = golden_mask.height * golden_mask.width; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc -index f52623ada454c..99cd0034287ad 100644 +index 0171e584fdd3d..78d78f5ddb6d1 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc @@ -46,8 +46,8 @@ constexpr char kMobileSsdWithMetadata[] = @@ -1642,7 +1863,7 @@ MoveAssignBase() = default; MoveAssignBase(const MoveAssignBase&) = 
default; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc -index d47c1ce7e5179..bb43d09f4a96b 100644 +index 11f9d584cfdd0..4d23efe43bc99 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc @@ -15,7 +15,7 @@ limitations under the License. @@ -1654,7 +1875,7 @@ #include "absl/strings/str_format.h" // from @com_google_absl #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/interpreter_utils.h" -@@ -304,7 +304,9 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() { +@@ -310,7 +310,9 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() { return absl::OkStatus(); } @@ -1665,7 +1886,7 @@ void TfLiteInterpreterWrapper::SetTfLiteCancellation() { // Create a cancellation check function and set to the TFLite interpreter. -@@ -317,7 +319,8 @@ void TfLiteInterpreterWrapper::SetTfLiteCancellation() { +@@ -323,7 +325,8 @@ void TfLiteInterpreterWrapper::SetTfLiteCancellation() { } absl::Status TfLiteInterpreterWrapper::LoadDelegatePlugin( @@ -1942,10 +2163,10 @@ } // namespace core } // namespace task diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc -index 5e17e14dc5f7a..b9ae32253cb29 100644 +index 9c4cc2009baea..e15830d5ab061 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc -@@ -24,11 +24,11 @@ limitations under the License. +@@ -18,11 +18,11 @@ limitations under the License. 
#include <memory> #include <string> @@ -1959,8 +2180,21 @@ namespace tflite { namespace task { +@@ -57,11 +57,10 @@ absl::Status ExternalFileHandler::MapExternalFile() { + StatusCode::kInvalidArgument, + "ExternalFile must specify 'file_content' in Chromium.", + TfLiteSupportStatus::kInvalidArgumentError); +- + } + + absl::string_view ExternalFileHandler::GetFileContent() { +- return external_file_.file_content(); ++ return external_file_.file_content(); + } + + ExternalFileHandler::~ExternalFileHandler() = default; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h -index e8b6831c6ad69..0b74e468d004f 100644 +index a7daa175f77f5..9f35fdd6d09ce 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h @@ -18,7 +18,7 @@ limitations under the License. @@ -1972,6 +2206,14 @@ #include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/statusor.h" +@@ -64,7 +64,6 @@ class ExternalFileHandler { + + // Reference to the input ExternalFile. 
+ const ExternalFile& external_file_; +- + }; + + } // namespace core diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc index 694c55ab34e78..72e4b670cb172 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/label_map_item.cc @@ -2180,10 +2422,10 @@ return FindTensorIndexByModelName(tensors, model_tensor_name); diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc -index e0f69cd1c80ac..2794290a2411e 100644 +index 5999090cab973..41e06389af80b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc -@@ -19,7 +19,7 @@ limitations under the License. +@@ -17,7 +17,7 @@ limitations under the License. 
#include <memory> @@ -2192,7 +2434,7 @@ #include "absl/strings/str_cat.h" // from @com_google_absl #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/shims/cc/kernels/register.h" -@@ -40,7 +40,8 @@ using ::tflite::support::CreateStatusWithPayload; +@@ -38,7 +38,8 @@ using ::tflite::support::CreateStatusWithPayload; using ::tflite::support::InterpreterCreationResources; using ::tflite::support::TfLiteSupportStatus; @@ -2202,7 +2444,7 @@ tflite::ErrorReporter* reporter) { return tflite_shims::Verify(data, length, reporter); } -@@ -71,7 +72,8 @@ std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() { +@@ -69,7 +70,8 @@ std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() { } void TfLiteEngine::VerifyAndBuildModelFromBuffer( @@ -2212,7 +2454,7 @@ TfLiteVerifier* extra_verifier) { model_ = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer( buffer_data, buffer_size, extra_verifier, &error_reporter_); -@@ -118,7 +120,8 @@ absl::Status TfLiteEngine::InitializeFromModelFileHandler( +@@ -116,7 +118,8 @@ absl::Status TfLiteEngine::InitializeFromModelFileHandler( } absl::Status TfLiteEngine::BuildModelFromFlatBuffer( @@ -2222,7 +2464,7 @@ const tflite::proto::ComputeSettings& compute_settings) { if (model_) { return CreateStatusWithPayload(StatusCode::kInternal, -@@ -207,7 +210,8 @@ absl::Status TfLiteEngine::InitInterpreter(int num_threads) { +@@ -205,7 +208,8 @@ absl::Status TfLiteEngine::InitInterpreter(int num_threads) { // absl::Status TfLiteEngine::InitInterpreter( // const tflite::proto::ComputeSettings& compute_settings) absl::Status TfLiteEngine::InitInterpreter( @@ -2233,10 +2475,10 @@ settings_copy.mutable_tflite_settings() ->mutable_cpu_settings() diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h -index 9b44c6e5c022a..1c6a067d6be9e 100644 +index 53dabdc4841d7..0cbaa738e6db6 100644 --- 
a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h -@@ -20,8 +20,8 @@ limitations under the License. +@@ -18,8 +18,8 @@ limitations under the License. #include <memory> @@ -2247,7 +2489,7 @@ #include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/shims/c/common.h" -@@ -98,7 +98,8 @@ class TfLiteEngine { +@@ -96,7 +96,8 @@ class TfLiteEngine { // object. This performs extra verification on the input data using // tflite::Verify. absl::Status BuildModelFromFlatBuffer( @@ -2257,7 +2499,7 @@ const tflite::proto::ComputeSettings& compute_settings = tflite::proto::ComputeSettings()); -@@ -140,7 +141,8 @@ class TfLiteEngine { +@@ -138,7 +139,8 @@ class TfLiteEngine { // absl::Status TfLiteEngine::InitInterpreter( // const tflite::proto::ComputeSettings& compute_settings) absl::Status InitInterpreter( @@ -2267,7 +2509,7 @@ // Cancels the on-going `Invoke()` call if any and if possible. This method // can be called from a different thread than the one where `Invoke()` is -@@ -157,7 +159,8 @@ class TfLiteEngine { +@@ -155,7 +157,8 @@ class TfLiteEngine { // the FlatBuffer data provided as input. class Verifier : public tflite::TfLiteVerifier { public: @@ -2434,10 +2676,10 @@ absl::Status Preprocess(const std::string& text); diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc -index a357973d8f925..e3bc2688caf3a 100644 +index 730c9919cadee..a2fa1f8243199 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc -@@ -22,16 +22,11 @@ limitations under the License. 
+@@ -22,17 +22,12 @@ limitations under the License. #include <memory> #include <vector> @@ -2448,20 +2690,22 @@ -#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h" -#include "absl/memory/memory.h" // from @com_google_absl -#include "absl/status/status.h" // from @com_google_absl -+#include "Eigen/Core" // from @eigen -+#include "absl/memory/memory.h" // from @com_google_absl -+#include "absl/status/status.h" // from @com_google_absl - #include "absl/strings/str_format.h" // from @com_google_absl +-#include "absl/strings/str_format.h" // from @com_google_absl ++#include "Eigen/Core" // from @eigen ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_format.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl -#include "absl/types/span.h" // from @com_google_absl -#include "Eigen/Core" // from @eigen -+#include "absl/types/span.h" // from @com_google_absl ++#include "absl/types/span.h" // from @com_google_absl #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/cc/port/statusor.h" -@@ -42,6 +37,11 @@ limitations under the License. - #include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h" - #include "tensorflow_lite_support/cc/task/processor/proto/search_options.pb.h" +@@ -45,6 +40,11 @@ limitations under the License. 
#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h" + #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" + #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h" @@ -2470,10 +2714,12 @@ #include "tensorflow_lite_support/scann_ondevice/cc/index.h" #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h" -@@ -53,14 +53,14 @@ namespace { +@@ -56,16 +56,16 @@ namespace { constexpr int kNoNeighborId = -1; ++using ::tflite::TensorMetadata; ++using ::tflite::metadata::ModelMetadataExtractor; +using ::tflite::scann_ondevice::Index; +using ::tflite::scann_ondevice::IndexConfig; using ::tflite::scann_ondevice::core::AsymmetricHashFindNeighbors; @@ -2482,12 +2728,14 @@ using ::tflite::scann_ondevice::core::QueryInfo; using ::tflite::scann_ondevice::core::ScannOnDeviceConfig; using ::tflite::scann_ondevice::core::TopN; +-using ::tflite::TensorMetadata; +-using ::tflite::metadata::ModelMetadataExtractor; -using ::tflite::scann_ondevice::Index; -using ::tflite::scann_ondevice::IndexConfig; using ::tflite::support::CreateStatusWithPayload; using ::tflite::support::StatusOr; using ::tflite::support::TfLiteSupportStatus; -@@ -191,7 +191,8 @@ absl::Status ConvertEmbeddingToEigenMatrix(const Embedding& embedding, +@@ -212,7 +212,8 @@ absl::Status ConvertEmbeddingToEigenMatrix(const Embedding& embedding, /* static */ StatusOr<std::unique_ptr<SearchPostprocessor>> SearchPostprocessor::Create( @@ -2497,7 +2745,7 @@ std::unique_ptr<SearchOptions> search_options, std::unique_ptr<EmbeddingOptions> embedding_options) { ASSIGN_OR_RETURN(auto embedding_postprocessor, -@@ -288,7 +289,8 @@ absl::Status SearchPostprocessor::Init( +@@ -316,7 +317,8 @@ absl::Status SearchPostprocessor::Init( 
index_config_.scann_config().partitioner().search_fraction())), partitioner_->NumPartitions()); } else { @@ -2507,7 +2755,7 @@ num_leaves_to_search_ = partitioner_->NumPartitions(); } -@@ -302,7 +304,8 @@ absl::Status SearchPostprocessor::Init( +@@ -330,7 +332,8 @@ absl::Status SearchPostprocessor::Init( } absl::Status SearchPostprocessor::QuantizedSearch( @@ -2517,7 +2765,7 @@ absl::Span<TopN> top_n) { int dim = index_config_.embedding_dim(); // Prepare QueryInfo used for all leaves. -@@ -332,7 +335,8 @@ absl::Status SearchPostprocessor::QuantizedSearch( +@@ -360,7 +363,8 @@ absl::Status SearchPostprocessor::QuantizedSearch( } absl::Status SearchPostprocessor::LinearSearch( @@ -4202,7 +4450,7 @@ - } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc -index f30af1e7d27d8..cea7ef3fb1f23 100644 +index 1854cf546d599..9a5b96160c033 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc @@ -18,7 +18,7 @@ limitations under the License. @@ -4261,7 +4509,7 @@ if (!AreBufferFormatsCompatible(buffer, output_buffer)) { return absl::InvalidArgumentError( "Input and output buffer formats must match."); -@@ -314,8 +321,10 @@ absl::Status ValidateConvertFormats(FrameBuffer::Format from_format, +@@ -309,8 +316,10 @@ absl::Status ValidateConvertFormats(FrameBuffer::Format from_format, // Creates a FrameBuffer from raw RGBA buffer and passing arguments. 
std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer( @@ -4274,7 +4522,7 @@ FrameBuffer::Stride stride) { if (stride == kDefaultStride) { stride.row_stride_bytes = dimension.width * kRgbaChannels; -@@ -330,8 +339,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer( +@@ -325,8 +334,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer( // Creates a FrameBuffer from raw RGB buffer and passing arguments. std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer( @@ -4287,7 +4535,7 @@ FrameBuffer::Stride stride) { if (stride == kDefaultStride) { stride.row_stride_bytes = dimension.width * kRgbChannels; -@@ -345,8 +356,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer( +@@ -340,8 +351,10 @@ std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer( // Creates a FrameBuffer from raw grayscale buffer and passing arguments. std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer( @@ -4300,7 +4548,7 @@ FrameBuffer::Stride stride) { if (stride == kDefaultStride) { stride.row_stride_bytes = dimension.width * kGrayChannel; -@@ -361,10 +374,16 @@ std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer( +@@ -356,10 +369,16 @@ std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer( // Creates a FrameBuffer from raw YUV buffer and passing arguments. StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer( @@ -4321,7 +4569,7 @@ const int pixel_stride_y = 1; std::vector<FrameBuffer::Plane> planes; if (format == FrameBuffer::Format::kNV21 || -@@ -385,9 +404,11 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer( +@@ -380,9 +399,11 @@ StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer( } StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer( @@ -4699,11 +4947,55 @@ FrameBuffer* output_buffer) = 0; // Flips `buffer` horizontally. 
+diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc +index 3f8bc7b43f4f1..d5b277ad33b89 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc +@@ -23,11 +23,11 @@ limitations under the License. + #define STB_IMAGE_IMPLEMENTATION + #define STB_IMAGE_WRITE_IMPLEMENTATION + +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/match.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/match.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl +-#include "stb_image.h" // from @stblib +-#include "stb_image_write.h" // from @stblib ++#include "stb_image.h" // from @stblib ++#include "stb_image_write.h" // from @stblib + #include "tensorflow_lite_support/cc/port/status_macros.h" + #include "tensorflow_lite_support/cc/port/statusor.h" + #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +@@ -88,7 +88,9 @@ absl::Status EncodeImageToPngFile(const ImageData& image_data, + return absl::OkStatus(); + } + +-void ImageDataFree(ImageData* image) { stbi_image_free(image->pixel_data); } ++void ImageDataFree(ImageData* image) { ++ stbi_image_free(image->pixel_data); ++} + + tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> + CreateFrameBufferFromImageData(const ImageData& image) { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h +index 6ba5c2d6490ab..7de32ee9c0f53 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h ++++ 
b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h +@@ -15,7 +15,7 @@ limitations under the License. + #ifndef TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ + #define TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ + +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/integral_types.h" + #include "tensorflow_lite_support/cc/port/statusor.h" diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc -index 53671cb88de51..a00c8223fac99 100644 +index e0dd8a99c64c0..a0ee2dab96b6a 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc -@@ -20,10 +20,10 @@ limitations under the License. +@@ -20,11 +20,11 @@ limitations under the License. 
#include <memory> #include <string> @@ -4713,11 +5005,13 @@ +#include "absl/strings/str_cat.h" // from @com_google_absl #include "absl/strings/str_format.h" // from @com_google_absl -#include "libyuv.h" // from @libyuv +-#include "libyuv/convert_argb.h" // from @libyuv +#include "libyuv.h" // from @libyuv ++#include "libyuv/convert_argb.h" // from @libyuv #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/status_macros.h" -@@ -383,7 +383,8 @@ absl::Status ResizeNv(const FrameBuffer& buffer, FrameBuffer* output_buffer) { +@@ -384,7 +384,8 @@ absl::Status ResizeNv(const FrameBuffer& buffer, FrameBuffer* output_buffer) { // Converts `buffer` to libyuv ARGB format and stores the conversion result // in `dest_argb`. @@ -4727,7 +5021,7 @@ int dest_stride_argb) { RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); if (buffer.format() != FrameBuffer::Format::kRGB) { -@@ -420,7 +421,8 @@ absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb, +@@ -421,7 +422,8 @@ absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb, // Converts `src_argb` in libyuv ARGB format to FrameBuffer::kRGB format and // stores the conversion result in `output_buffer`. @@ -4737,7 +5031,7 @@ FrameBuffer* output_buffer) { RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer)); if (output_buffer->format() != FrameBuffer::Format::kRGB) { -@@ -456,7 +458,8 @@ absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb, +@@ -457,7 +459,8 @@ absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb, // Converts `buffer` in FrameBuffer::kRGBA format to libyuv ARGB (BGRA in // memory) format and stores the conversion result in `dest_argb`. 
@@ -4747,7 +5041,7 @@ int dest_stride_argb) { RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); if (buffer.format() != FrameBuffer::Format::kRGBA) { -@@ -674,7 +677,8 @@ libyuv::RotationMode GetLibyuvRotationMode(int angle_deg) { +@@ -689,7 +692,8 @@ libyuv::RotationMode GetLibyuvRotationMode(int angle_deg) { } } @@ -4757,7 +5051,7 @@ FrameBuffer* output_buffer) { if (buffer.plane_count() > 1) { return CreateStatusWithPayload( -@@ -698,7 +702,8 @@ absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg, +@@ -713,7 +717,8 @@ absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg, return absl::OkStatus(); } @@ -4767,7 +5061,7 @@ FrameBuffer* output_buffer) { // libyuv does not support rotate kRGB (RGB24) foramat. In this method, the // implementation converts kRGB format to ARGB and use ARGB buffer for -@@ -731,7 +736,8 @@ absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg, +@@ -746,7 +751,8 @@ absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg, output_buffer); } @@ -4777,7 +5071,7 @@ FrameBuffer* output_buffer) { if (buffer.plane_count() > 1) { return CreateStatusWithPayload( -@@ -754,7 +760,8 @@ absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg, +@@ -769,7 +775,8 @@ absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg, } // Rotates YV12/YV21 frame buffer. @@ -4787,7 +5081,7 @@ FrameBuffer* output_buffer) { ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); -@@ -779,7 +786,8 @@ absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg, +@@ -794,7 +801,8 @@ absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg, // Rotates NV12/NV21 frame buffer. // TODO(b/152097364): Refactor NV12/NV21 rotation after libyuv explicitly // support that. 
@@ -4797,7 +5091,7 @@ FrameBuffer* output_buffer) { if (buffer.format() != FrameBuffer::Format::kNV12 && buffer.format() != FrameBuffer::Format::kNV21) { -@@ -869,8 +877,12 @@ absl::Status FlipPlaneVertically(const FrameBuffer& buffer, +@@ -884,8 +892,12 @@ absl::Status FlipPlaneVertically(const FrameBuffer& buffer, } // This method only supports kGRAY, kRGBA, and kRGB formats. @@ -4812,7 +5106,7 @@ if (buffer.plane_count() > 1) { return CreateStatusWithPayload( StatusCode::kInternal, -@@ -897,7 +909,11 @@ absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1, +@@ -912,7 +924,11 @@ absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1, // Crops NV12/NV21 FrameBuffer to the subregion defined by the top left pixel // position (x0, y0) and the bottom right pixel position (x1, y1). @@ -4825,7 +5119,7 @@ FrameBuffer* output_buffer) { ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); -@@ -929,7 +945,11 @@ absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, +@@ -944,7 +960,11 @@ absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, // Crops YV12/YV21 FrameBuffer to the subregion defined by the top left pixel // position (x0, y0) and the bottom right pixel position (x1, y1). 
@@ -4838,7 +5132,7 @@ FrameBuffer* output_buffer) { ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); -@@ -964,8 +984,12 @@ absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, +@@ -979,8 +999,12 @@ absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, return absl::OkStatus(); } @@ -4853,7 +5147,7 @@ FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1); if (crop_dimension == output_buffer->dimension()) { switch (buffer.format()) { -@@ -1293,8 +1317,12 @@ absl::Status ResizeGray(const FrameBuffer& buffer, FrameBuffer* output_buffer) { +@@ -1308,8 +1332,12 @@ absl::Status ResizeGray(const FrameBuffer& buffer, FrameBuffer* output_buffer) { } // This method only supports kGRAY, kRGBA, and kRGB formats. @@ -4868,7 +5162,7 @@ FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1); if (crop_dimension == output_buffer->dimension()) { return CropPlane(buffer, x0, y0, x1, y1, output_buffer); -@@ -1328,8 +1356,11 @@ absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1, +@@ -1343,8 +1371,11 @@ absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1, } // namespace @@ -4882,7 +5176,7 @@ FrameBuffer* output_buffer) { RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer)); -@@ -1410,7 +1441,8 @@ absl::Status LibyuvFrameBufferUtils::Rotate(const FrameBuffer& buffer, +@@ -1425,7 +1456,8 @@ absl::Status LibyuvFrameBufferUtils::Rotate(const FrameBuffer& buffer, } absl::Status LibyuvFrameBufferUtils::FlipHorizontally( @@ -4892,7 +5186,7 @@ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer)); RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer)); -@@ -1438,7 +1470,8 @@ absl::Status LibyuvFrameBufferUtils::FlipHorizontally( +@@ -1453,7 +1485,8 @@ absl::Status 
LibyuvFrameBufferUtils::FlipHorizontally( } absl::Status LibyuvFrameBufferUtils::FlipVertically( @@ -4992,7 +5286,7 @@ #include "tensorflow_lite_support/cc/port/gtest.h" diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc -index d0a7e33129e7e..9ae943548dc63 100644 +index 9a00e2f9e89a1..ef0e783e97c3e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc @@ -46,8 +46,8 @@ constexpr char kTestDataDirectory[] = @@ -5425,7 +5719,7 @@ std::vector<core::Category> expected_class = { {"label0", 255}, diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc -index 31195ce525b69..931a44e072c7a 100644 +index 5a86a288b4624..b097813ecedf7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc @@ -17,7 +17,7 @@ limitations under the License. @@ -5448,7 +5742,27 @@ return options; } -@@ -130,7 +130,7 @@ TEST(EmbedTest, SucceedsWithRegexModel) { +@@ -130,7 +130,7 @@ TEST(EmbedTest, SucceedsWithMobileBertModel) { + TextEmbedderOptions options = GetBasicOptions(kMobileBert); + // No Embedding options means all head get a default option. 
+ SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder, +- TextEmbedder::CreateFromOptions(options)); ++ TextEmbedder::CreateFromOptions(options)); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + auto result0, +@@ -141,8 +141,8 @@ TEST(EmbedTest, SucceedsWithMobileBertModel) { + EXPECT_NEAR(result0.embeddings(0).feature_vector().value_float(0), 19.9016f, + kValueDiffTolerance); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(auto result1, +- text_embedder->Embed("what a great and fantastic trip")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ auto result1, text_embedder->Embed("what a great and fantastic trip")); + EXPECT_EQ(result1.embeddings_size(), 1); + EXPECT_EQ(result1.embeddings(0).feature_vector().value_float_size(), 512); + +@@ -162,7 +162,7 @@ TEST(EmbedTest, SucceedsWithRegexModel) { TextEmbedderOptions options = GetBasicOptions(kRegexOneEmbeddingModel); // No Embedding options means all head get a default option. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder, @@ -5457,7 +5771,7 @@ SUPPORT_ASSERT_OK_AND_ASSIGN( auto result0, -@@ -141,8 +141,8 @@ TEST(EmbedTest, SucceedsWithRegexModel) { +@@ -173,8 +173,8 @@ TEST(EmbedTest, SucceedsWithRegexModel) { EXPECT_NEAR(result0.embeddings(0).feature_vector().value_float(0), 0.0309356f, kValueDiffTolerance); @@ -5468,7 +5782,7 @@ EXPECT_EQ(result1.embeddings_size(), 1); EXPECT_EQ(result1.embeddings(0).feature_vector().value_float_size(), 16); -@@ -174,8 +174,8 @@ TEST(EmbedTest, SucceedsWithUniversalSentenceEncoder) { +@@ -206,8 +206,8 @@ TEST(EmbedTest, SucceedsWithUniversalSentenceEncoder) { EXPECT_NEAR(result0.embeddings(0).feature_vector().value_float(0), 1.422951f, kValueDiffTolerance); @@ -5479,7 +5793,7 @@ EXPECT_EQ(result1.embeddings_size(), 1); EXPECT_EQ(result1.embeddings(0).feature_vector().value_float_size(), 100); -@@ -195,7 +195,7 @@ TEST(GetEmbeddingDimension, Succeeds) { +@@ -227,7 +227,7 @@ TEST(GetEmbeddingDimension, Succeeds) { // Create embedder. 
TextEmbedderOptions options = GetBasicOptions(kMobileBert); SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder, @@ -5488,7 +5802,7 @@ EXPECT_EQ(text_embedder->GetEmbeddingDimension(0), 512); EXPECT_EQ(text_embedder->GetEmbeddingDimension(1), -1); -@@ -206,7 +206,7 @@ TEST(GetNumberOfOutputLayers, Succeeds) { +@@ -238,7 +238,7 @@ TEST(GetNumberOfOutputLayers, Succeeds) { TextEmbedderOptions options = GetBasicOptions(kMobileBert); // No Embedding options means all head get a default option. SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder, @@ -5497,6 +5811,70 @@ EXPECT_EQ(text_embedder->GetNumberOfOutputLayers(), kNumberOfOutputLayers); } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc +index fec09a1ad77cc..f38615c5b3092 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc +@@ -18,9 +18,9 @@ limitations under the License. 
+ #include <memory> + #include <string> + +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/cord.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/cord.h" // from @com_google_absl + #include "absl/strings/str_cat.h" // from @com_google_absl + #include "tensorflow/lite/core/api/op_resolver.h" + #include "tensorflow/lite/core/shims/cc/shims_test_util.h" +@@ -219,7 +219,8 @@ TEST_P(CreateFromOptionsTest, FailsWithInvalidMaxResults) { + } + + INSTANTIATE_TEST_SUITE_P( +- CreateFromOptionsTest, CreateFromOptionsTest, ++ CreateFromOptionsTest, ++ CreateFromOptionsTest, + Values(CreateFromOptionsParams{.name = "Bert", + .embedder_model_name = kMobileBertEmbedder, + .searcher_model_name = kMobileBertSearcher, +@@ -267,7 +268,7 @@ TEST_P(SearchTest, SucceedsWithStandaloneIndex) { + + // Perform search. + SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, +- searcher->Search("The weather was excellent.")); ++ searcher->Search("The weather was excellent.")); + + // Check results. + ExpectApproximatelyEqual( +@@ -288,7 +289,7 @@ TEST_P(SearchTest, SucceedsWithMetadataIndex) { + + // Perform search. + SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, +- searcher->Search("The weather was excellent.")); ++ searcher->Search("The weather was excellent.")); + + // Check results. + ExpectApproximatelyEqual( +@@ -313,7 +314,7 @@ TEST_P(SearchTest, SucceedsWithMaxResults) { + + // Perform search. + SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, +- searcher->Search("The weather was excellent.")); ++ searcher->Search("The weather was excellent.")); + + // Check results. 
+ SearchResult all_results = +@@ -327,7 +328,8 @@ TEST_P(SearchTest, SucceedsWithMaxResults) { + } + + INSTANTIATE_TEST_SUITE_P( +- SearchTest, SearchTest, ++ SearchTest, ++ SearchTest, + Values( + SearchParams{ + .name = "Bert", diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/universal_sentence_encoder_qa_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/universal_sentence_encoder_qa_test.cc index 2529060cab275..5f0535b5c1438 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/universal_sentence_encoder_qa_test.cc @@ -5549,7 +5927,7 @@ *input.mutable_responses()->Add()->mutable_text_encoding() = resp1; *input.mutable_responses()->Add()->mutable_text_encoding() = resp2; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc -index 86711635ce467..c40836ed3b125 100644 +index 6a0ce66dde9b5..2daf293b48f05 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc @@ -17,9 +17,9 @@ limitations under the License. @@ -5882,7 +6260,7 @@ ExpectApproximatelyEqual( result, diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc -index d5606d12440b0..8877f28b98beb 100644 +index 6ce017d3f1728..41226f602a26b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc @@ -17,7 +17,7 @@ limitations under the License. 
@@ -6130,8 +6508,176 @@ EXPECT_EQ(embedder->GetNumberOfOutputLayers(), 1); } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc +index 0b1f3b11b383c..00183eb65b5df 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc +@@ -18,9 +18,9 @@ limitations under the License. + #include <memory> + #include <string> + +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl +-#include "absl/strings/cord.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/cord.h" // from @com_google_absl + #include "absl/strings/str_cat.h" // from @com_google_absl + #include "tensorflow/lite/core/shims/cc/shims_test_util.h" + #include "tensorflow_lite_support/cc/common.h" +@@ -66,8 +66,8 @@ constexpr char kMobileNetV3Searcher[] = + "mobilenet_v3_small_100_224_searcher.tflite"; + + StatusOr<ImageData> LoadImage(std::string image_name) { +- return DecodeImageFromFile(JoinPath("./" /*test src dir*/, +- kTestDataDirectory, image_name)); ++ return DecodeImageFromFile( ++ JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); + } + + // Checks that the two provided `SearchResult` protos are equal, with a +@@ -88,9 +88,8 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + + TEST_F(CreateFromOptionsTest, SucceedsWithStandaloneIndex) { + ImageSearcherOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetV3Embedder)); ++ 
options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex)); +@@ -100,9 +99,8 @@ TEST_F(CreateFromOptionsTest, SucceedsWithStandaloneIndex) { + + TEST_F(CreateFromOptionsTest, SucceedsWithMetadataIndex) { + ImageSearcherOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetV3Searcher)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Searcher)); + options.mutable_embedding_options()->set_l2_normalize(true); + + SUPPORT_ASSERT_OK(ImageSearcher::CreateFromOptions(options)); +@@ -129,9 +127,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { + + TEST_F(CreateFromOptionsTest, FailsWithMissingIndex) { + ImageSearcherOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetV3Embedder)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + + StatusOr<std::unique_ptr<ImageSearcher>> image_searcher_or = +@@ -151,9 +148,8 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingIndex) { + + TEST_F(CreateFromOptionsTest, FailsWithQuantization) { + ImageSearcherOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetV3Embedder)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, 
kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_embedding_options()->set_quantize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( +@@ -174,9 +170,8 @@ TEST_F(CreateFromOptionsTest, FailsWithQuantization) { + + TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { + ImageSearcherOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetV3Embedder)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex)); +@@ -197,14 +192,13 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { + TEST(SearchTest, SucceedsWithStandaloneIndex) { + // Create Searcher. + ImageSearcherOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetV3Embedder)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher, +- ImageSearcher::CreateFromOptions(options)); ++ ImageSearcher::CreateFromOptions(options)); + // Load image. 
+ SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( +@@ -212,7 +206,7 @@ TEST(SearchTest, SucceedsWithStandaloneIndex) { + + // Perform search. + SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, +- searcher->Search(*frame_buffer)); ++ searcher->Search(*frame_buffer)); + ImageDataFree(&image); + + // Check results. +@@ -229,12 +223,11 @@ TEST(SearchTest, SucceedsWithStandaloneIndex) { + TEST(SearchTest, SucceedsWithMetadataIndex) { + // Create Searcher. + ImageSearcherOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetV3Searcher)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Searcher)); + options.mutable_embedding_options()->set_l2_normalize(true); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher, +- ImageSearcher::CreateFromOptions(options)); ++ ImageSearcher::CreateFromOptions(options)); + // Load image. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( +@@ -242,7 +235,7 @@ TEST(SearchTest, SucceedsWithMetadataIndex) { + + // Perform search. + SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, +- searcher->Search(*frame_buffer)); ++ searcher->Search(*frame_buffer)); + ImageDataFree(&image); + + // Check results. +@@ -259,15 +252,14 @@ TEST(SearchTest, SucceedsWithMetadataIndex) { + TEST(SearchTest, SucceedsWithMaxResults) { + // Create Searcher. 
+ ImageSearcherOptions options; +- options.mutable_base_options()->mutable_model_file()->set_file_name( +- JoinPath("./" /*test src dir*/, kTestDataDirectory, +- kMobileNetV3Embedder)); ++ options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( ++ "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex)); + options.mutable_search_options()->set_max_results(2); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher, +- ImageSearcher::CreateFromOptions(options)); ++ ImageSearcher::CreateFromOptions(options)); + // Load image. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( +@@ -275,7 +267,7 @@ TEST(SearchTest, SucceedsWithMaxResults) { + + // Perform search. + SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, +- searcher->Search(*frame_buffer)); ++ searcher->Search(*frame_buffer)); + ImageDataFree(&image); + + // Check results. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc -index 3aab0bbee48ef..dc768a43a8726 100644 +index e32b8e4c27524..8671b68c3b884 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc @@ -17,9 +17,9 @@ limitations under the License. 
@@ -6415,7 +6961,7 @@ EXPECT_EQ(result.segmentation_size(), 1); const Segmentation& segmentation = result.segmentation(0); diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc -index ef1f6509080ed..4a33e4b479354 100644 +index a4f35574d7bfe..6c0f395868e20 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc @@ -17,9 +17,9 @@ limitations under the License. @@ -7095,7 +7641,7 @@ jclass e_class = env->FindClass(clazz); if (strcmp(clazz, kAssertionError) == 0) { diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h -index 7f0674d3c9187..7caf49e479859 100644 +index 6d15bb43e75b3..f92f838bb9a71 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h @@ -22,7 +22,7 @@ limitations under the License. @@ -7107,17 +7653,18 @@ #include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" #include "tensorflow_lite_support/cc/port/statusor.h" -@@ -57,7 +57,8 @@ T CheckNotNull(JNIEnv* env, T&& t) { - // Converts a std::vector<T> into a Java ArrayList using a converter, which - // processes a single element in the vector before adding it to the ArrayList. 
- template <typename T> --jobject ConvertVectorToArrayList(JNIEnv* env, const std::vector<T>& results, -+jobject ConvertVectorToArrayList(JNIEnv* env, -+ const std::vector<T>& results, - std::function<jobject(T)> converter) { +@@ -59,7 +59,9 @@ T CheckNotNull(JNIEnv* env, T&& t) { + // interable before adding it to the ArrayList. + template <typename Iterator> + jobject ConvertVectorToArrayList( +- JNIEnv* env, const Iterator& begin, const Iterator& end, ++ JNIEnv* env, ++ const Iterator& begin, ++ const Iterator& end, + std::function<jobject(typename std::iterator_traits<Iterator>::value_type)> + converter) { jclass array_list_class = env->FindClass("java/util/ArrayList"); - jmethodID array_list_ctor = -@@ -91,7 +92,8 @@ jbyteArray CreateByteArray(JNIEnv* env, const jbyte* data, int num_bytes); +@@ -94,7 +96,8 @@ jbyteArray CreateByteArray(JNIEnv* env, const jbyte* data, int num_bytes); void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...); @@ -8204,6 +8751,115 @@ "If true, inference will be delegated to a connected Coral Edge TPU " "device."); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc +index 875b5f4a771bd..eca8a002d3293 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc +@@ -24,9 +24,9 @@ limitations under the License. 
+ #include <iostream> + #include <memory> + +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/flags/parse.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/parse.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" + #include "tensorflow_lite_support/cc/port/status_macros.h" +@@ -36,19 +36,29 @@ limitations under the License. + #include "tensorflow_lite_support/cc/task/text/text_embedder.h" + #include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h" + +-ABSL_FLAG(std::string, model_path, "", ++ABSL_FLAG(std::string, ++ model_path, ++ "", + "Absolute path to the '.tflite' text embedder model."); +-ABSL_FLAG(std::string, first_sentence, "", ++ABSL_FLAG(std::string, ++ first_sentence, ++ "", + "First sentence, whose feature vector will be extracted and compared " + "to the second sentence using cosine similarity."); +-ABSL_FLAG(std::string, second_sentence, "", ++ABSL_FLAG(std::string, ++ second_sentence, ++ "", + "Second sentence, whose feature vector will be extracted and " + "compared to the first sentence using cosine similarity."); +-ABSL_FLAG(bool, l2_normalize, false, ++ABSL_FLAG(bool, ++ l2_normalize, ++ false, + "If true, the raw feature vectors returned by the image embedder " + "will be normalized with L2-norm. 
Generally only needed if the model " + "doesn't already contain a L2_NORMALIZATION TFLite Op."); +-ABSL_FLAG(bool, use_coral, false, ++ABSL_FLAG(bool, ++ use_coral, ++ false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc +index 5ea9b7e63b50e..0299428964797 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc +@@ -24,9 +24,9 @@ limitations under the License. + #include <iostream> + #include <memory> + +-#include "absl/flags/flag.h" // from @com_google_absl +-#include "absl/flags/parse.h" // from @com_google_absl +-#include "absl/status/status.h" // from @com_google_absl ++#include "absl/flags/flag.h" // from @com_google_absl ++#include "absl/flags/parse.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl + #include "absl/strings/str_format.h" // from @com_google_absl + #include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" + #include "tensorflow_lite_support/cc/port/status_macros.h" +@@ -39,21 +39,33 @@ limitations under the License. + #include "tensorflow_lite_support/cc/task/text/text_searcher.h" + #include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h" + +-ABSL_FLAG(std::string, model_path, "", ++ABSL_FLAG(std::string, ++ model_path, ++ "", + "Absolute path to the '.tflite' text embedder model."); +-ABSL_FLAG(std::string, index_path, "", ++ABSL_FLAG(std::string, ++ index_path, ++ "", + "Absolute path to the index to search into. 
Mandatory only if the " + "index is not attached to the output tensor metadata of the embedder " + "model as an AssociatedFile with type SCANN_INDEX_FILE."); +-ABSL_FLAG(std::string, input_sentence, "", ++ABSL_FLAG(std::string, ++ input_sentence, ++ "", + "Input sentence whose nearest-neighbors to search for in the index."); +-ABSL_FLAG(int32, max_results, 5, ++ABSL_FLAG(int32, ++ max_results, ++ 5, + "Maximum number of nearest-neghbors to display."); +-ABSL_FLAG(bool, l2_normalize, false, ++ABSL_FLAG(bool, ++ l2_normalize, ++ false, + "If true, the raw feature vectors returned by the image embedder " + "will be normalized with L2-norm. Generally only needed if the model " + "doesn't already contain a L2_NORMALIZATION TFLite Op."); +-ABSL_FLAG(bool, use_coral, false, ++ABSL_FLAG(bool, ++ use_coral, ++ false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc index 076a60a2330af..f7621a5a8a1b4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_demo.cc @@ -8243,7 +8899,7 @@ "Candidate answers seperated by `:`."); diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc -index 8b2ed939686b3..bd2aaaf188726 100644 +index f29bd2de9c535..0904920faa7dd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc @@ -22,9 +22,9 @@ limitations under the License. 
@@ -8261,7 +8917,7 @@ #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" @@ -36,29 +36,43 @@ limitations under the License. #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" - #include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" -ABSL_FLAG(std::string, model_path, "", +ABSL_FLAG(std::string, @@ -8311,7 +8967,7 @@ "device."); diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc -index 722194f34ee5e..040878aa37841 100644 +index 50d615a486751..f8b1796bc3865 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc @@ -26,9 +26,9 @@ limitations under the License. @@ -8329,7 +8985,7 @@ #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" @@ -39,28 +39,40 @@ limitations under the License. 
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" - #include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" -ABSL_FLAG(std::string, model_path, "", +ABSL_FLAG(std::string, @@ -8375,7 +9031,7 @@ "device."); diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc -index c23d90f982084..a188b5d558343 100644 +index b661447614bc7..e4074f76dba5b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc @@ -25,9 +25,9 @@ limitations under the License. @@ -8391,9 +9047,9 @@ #include "absl/strings/str_format.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/cc/port/statusor.h" -@@ -42,21 +42,33 @@ limitations under the License. +@@ -42,23 +42,35 @@ limitations under the License. #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" - #include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" -ABSL_FLAG(std::string, model_path, "", +ABSL_FLAG(std::string, @@ -8404,7 +9060,9 @@ +ABSL_FLAG(std::string, + index_path, + "", - "Absolute path to the index to search into."); + "Absolute path to the index to search into. Mandatory only if the " + "index is not attached to the output tensor metadata of the embedder " + "model as an AssociatedFile with type SCANN_INDEX_FILE."); -ABSL_FLAG(std::string, image_path, "", +ABSL_FLAG(std::string, + image_path, @@ -8412,11 +9070,11 @@ "Absolute path to the image to search. 
The image must be RGB or " "RGBA (grayscale is not supported). The image EXIF orientation " "flag, if any, is NOT taken into account."); --ABSL_FLAG(int32, num_results, 5, +-ABSL_FLAG(int32, max_results, 5, +ABSL_FLAG(int32, -+ num_results, ++ max_results, + 5, - "Number of nearest-neighbor results to display."); + "Maximum number of nearest-neighbor results to display."); -ABSL_FLAG(bool, l2_normalize, false, +ABSL_FLAG(bool, + l2_normalize, @@ -8432,7 +9090,7 @@ "device."); diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc -index 2cb606e011aca..6487fe92166cd 100644 +index 5a566ecbcf921..fdc787288fa06 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc @@ -23,10 +23,10 @@ limitations under the License. @@ -8452,7 +9110,7 @@ #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" @@ -37,16 +37,24 @@ limitations under the License. 
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" - #include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" -ABSL_FLAG(std::string, model_path, "", +ABSL_FLAG(std::string, @@ -8480,7 +9138,7 @@ "device."); diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc -index 0130b4550b9d9..9208439df6263 100644 +index 20f7403207c2e..fd000fccf2f29 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc @@ -24,10 +24,10 @@ limitations under the License. @@ -8500,7 +9158,7 @@ #include "tensorflow_lite_support/cc/task/core/external_file_handler.h" @@ -40,32 +40,48 @@ limitations under the License. 
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" - #include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + #include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" -ABSL_FLAG(std::string, model_path, "", +ABSL_FLAG(std::string, @@ -8555,50 +9213,6 @@ "If true, inference will be delegated to a connected Coral Edge TPU " "device."); -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc -index e0a8a4c36ddc4..d5c0c589e375c 100644 ---- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc -+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc -@@ -23,11 +23,11 @@ limitations under the License. - #define STB_IMAGE_IMPLEMENTATION - #define STB_IMAGE_WRITE_IMPLEMENTATION - --#include "absl/status/status.h" // from @com_google_absl --#include "absl/strings/match.h" // from @com_google_absl -+#include "absl/status/status.h" // from @com_google_absl -+#include "absl/strings/match.h" // from @com_google_absl - #include "absl/strings/str_format.h" // from @com_google_absl --#include "stb_image.h" // from @stblib --#include "stb_image_write.h" // from @stblib -+#include "stb_image.h" // from @stblib -+#include "stb_image_write.h" // from @stblib - #include "tensorflow_lite_support/cc/port/status_macros.h" - #include "tensorflow_lite_support/cc/port/statusor.h" - #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" -@@ -88,7 +88,9 @@ absl::Status EncodeImageToPngFile(const ImageData& image_data, - return absl::OkStatus(); - } - --void ImageDataFree(ImageData* image) { stbi_image_free(image->pixel_data); } -+void ImageDataFree(ImageData* image) { -+ stbi_image_free(image->pixel_data); -+} - - 
tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> - CreateFrameBufferFromImageData(const ImageData& image) { -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h -index 6ba5c2d6490ab..7de32ee9c0f53 100644 ---- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h -+++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h -@@ -15,7 +15,7 @@ limitations under the License. - #ifndef TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ - #define TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ - --#include "absl/status/status.h" // from @com_google_absl -+#include "absl/status/status.h" // from @com_google_absl - #include "absl/strings/string_view.h" // from @com_google_absl - #include "tensorflow_lite_support/cc/port/integral_types.h" - #include "tensorflow_lite_support/cc/port/statusor.h" diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h index a4fee55abe158..2ca42fb7f3fbe 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/sources/TFLCommon.h @@ -8726,6 +9340,248 @@ return NO; } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h +index 79b6ba238e982..a5db97038a047 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h +@@ -23,26 +23,28 @@ NS_ASSUME_NONNULL_BEGIN + 
@property(nonatomic, readonly) NSUInteger size; + + /** Pointer to float array wrapped by `TFLFloatBuffer`. */ +-@property(nonatomic, readonly) float *data; ++@property(nonatomic, readonly) float* data; + + /** +- * Initializes a new `TFLFloatBuffer` by copying the elements of the given float data array. ++ * Initializes a new `TFLFloatBuffer` by copying the elements of the given float ++ * data array. + * +- * @param data A pointer to a float data array whose values are to be copied into the buffer. ++ * @param data A pointer to a float data array whose values are to be copied ++ * into the buffer. + * @param size Size of the array float data array. + * +- * @return A new instance of `TFLFloatBuffer` initialized with the elements of the given float data +- * array. ++ * @return A new instance of `TFLFloatBuffer` initialized with the elements of ++ * the given float data array. + */ +-- (instancetype)initWithData:(float *)data size:(NSUInteger)size; ++- (instancetype)initWithData:(float*)data size:(NSUInteger)size; + + /** + * Initializes a `TFLFloatBuffer` of the specified size with zeros. + * + * @param size Number of elements the `TFLFloatBuffer` can hold. + * +- * @return A new instance of `TFLFloatBuffer` of the given size with all elements initialized to +- * zero. ++ * @return A new instance of `TFLFloatBuffer` of the given size with all ++ * elements initialized to zero. 
+ */ + - (instancetype)initWithSize:(NSUInteger)size; + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m +index 24d50affb27aa..d32fc4363efc2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m +@@ -16,7 +16,7 @@ + + @implementation TFLFloatBuffer + +-- (instancetype)initWithData:(float *)data size:(NSUInteger)size { ++- (instancetype)initWithData:(float*)data size:(NSUInteger)size { + self = [self init]; + if (self) { + _size = size; +@@ -43,7 +43,7 @@ + return self; + } + +-- (id)copyWithZone:(NSZone *)zone { ++- (id)copyWithZone:(NSZone*)zone { + return [[TFLFloatBuffer alloc] initWithData:_data size:_size]; + } + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h +index 5a0ab68974b88..b300de6b94d89 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h +@@ -17,13 +17,14 @@ + + NS_ASSUME_NONNULL_BEGIN + +-/** An wrapper class which stores a buffer that is written in circular fashion. */ ++/** An wrapper class which stores a buffer that is written in circular fashion. ++ */ + @interface TFLRingBuffer : NSObject + + /** + * A copy of all the internal ring buffer elements in order. + */ +-@property(nullable, nonatomic, readonly) TFLFloatBuffer *floatBuffer; ++@property(nullable, nonatomic, readonly) TFLFloatBuffer* floatBuffer; + + /** + * Capacity of the ring buffer in number of elements. 
+@@ -36,34 +37,37 @@ NS_ASSUME_NONNULL_BEGIN + * + * @param size Size of the ring buffer. + * +- * @return A new instance of `TFLRingBuffer` with the given size and all elements +- * initialized to zero. ++ * @return A new instance of `TFLRingBuffer` with the given size and all ++ * elements initialized to zero. + */ + - (instancetype)initWithBufferSize:(NSUInteger)size; + + /** +- * Loads a slice of a float array to the ring buffer. If the float array is longer than ring +- * buffer's capacity, samples with lower indices in the array will be ignored. ++ * Loads a slice of a float array to the ring buffer. If the float array is ++ * longer than ring buffer's capacity, samples with lower indices in the array ++ * will be ignored. + * + * @return Boolean indicating success or failure of loading operation. + */ +-- (BOOL)loadBuffer:(TFLFloatBuffer *)sourceBuffer ++- (BOOL)loadBuffer:(TFLFloatBuffer*)sourceBuffer + offset:(NSUInteger)offset + size:(NSUInteger)size +- error:(NSError **)error; ++ error:(NSError**)error; + + /** +- * Returns a `TFLFloatBuffer` with a copy of size number of the ring buffer elements in order +- * starting at offset, i.e, buffer[offset:offset+size]. ++ * Returns a `TFLFloatBuffer` with a copy of size number of the ring buffer ++ * elements in order starting at offset, i.e, buffer[offset:offset+size]. + * +- * @param offset Offset in the ring buffer from which elements are to be returned. ++ * @param offset Offset in the ring buffer from which elements are to be ++ * returned. + * + * @param size Number of elements to be returned. + * +- * @return A new `TFLFloatBuffer` if offset + size is within the bounds of the ring buffer, +- * otherwise nil. ++ * @return A new `TFLFloatBuffer` if offset + size is within the bounds of the ++ * ring buffer, otherwise nil. 
+ */ +-- (nullable TFLFloatBuffer *)floatBufferWithOffset:(NSUInteger)offset size:(NSUInteger)size; ++- (nullable TFLFloatBuffer*)floatBufferWithOffset:(NSUInteger)offset ++ size:(NSUInteger)size; + + /** + * Clears the `TFLRingBuffer` by setting all the elements to zero . +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m +index 675f7058fff61..57495409f51c8 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m +@@ -18,7 +18,7 @@ + + @implementation TFLRingBuffer { + NSUInteger _nextIndex; +- TFLFloatBuffer *_buffer; ++ TFLFloatBuffer* _buffer; + } + + - (instancetype)initWithBufferSize:(NSUInteger)size { +@@ -29,18 +29,18 @@ + return self; + } + +-- (BOOL)loadBuffer:(TFLFloatBuffer *)sourceBuffer ++- (BOOL)loadBuffer:(TFLFloatBuffer*)sourceBuffer + offset:(NSUInteger)offset + size:(NSUInteger)size +- error:(NSError **)error { ++ error:(NSError**)error { + NSUInteger sizeToCopy = size; + NSUInteger newOffset = offset; + + if (offset + size > sourceBuffer.size) { +- [TFLCommonUtils +- createCustomError:error +- withCode:TFLSupportErrorCodeInvalidArgumentError +- description:@"offset + size exceeds the maximum size of the source buffer."]; ++ [TFLCommonUtils createCustomError:error ++ withCode:TFLSupportErrorCodeInvalidArgumentError ++ description:@"offset + size exceeds the maximum size " ++ @"of the source buffer."]; + return NO; + } + +@@ -51,13 +51,15 @@ + newOffset = offset + (size - _buffer.size); + } + +- // If the new nextIndex + sizeToCopy is smaller than the size of the ring buffer directly +- // copy all elements to the end of the ring buffer. 
++ // If the new nextIndex + sizeToCopy is smaller than the size of the ring ++ // buffer directly copy all elements to the end of the ring buffer. + if (_nextIndex + sizeToCopy < _buffer.size) { +- memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset, sizeof(float) * sizeToCopy); ++ memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset, ++ sizeof(float) * sizeToCopy); + } else { + NSUInteger endChunkSize = _buffer.size - _nextIndex; +- memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset, sizeof(float) * endChunkSize); ++ memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset, ++ sizeof(float) * endChunkSize); + + NSUInteger startChunkSize = sizeToCopy - endChunkSize; + memcpy(_buffer.data, sourceBuffer.data + newOffset + endChunkSize, +@@ -69,16 +71,17 @@ + return YES; + } + +-- (TFLFloatBuffer *)floatBuffer { ++- (TFLFloatBuffer*)floatBuffer { + return [self floatBufferWithOffset:0 size:self.size]; + } + +-- (nullable TFLFloatBuffer *)floatBufferWithOffset:(NSUInteger)offset size:(NSUInteger)size { ++- (nullable TFLFloatBuffer*)floatBufferWithOffset:(NSUInteger)offset ++ size:(NSUInteger)size { + if (offset + size > _buffer.size) { + return nil; + } + +- TFLFloatBuffer *bufferToReturn = [[TFLFloatBuffer alloc] initWithSize:size]; ++ TFLFloatBuffer* bufferToReturn = [[TFLFloatBuffer alloc] initWithSize:size]; + + // Return buffer in correct order. + // Compute offset in flat ring buffer array considering warping. +@@ -86,17 +89,21 @@ + + // If no; elements to be copied are within the end of the flat ring buffer. + if ((correctOffset + size) <= _buffer.size) { +- memcpy(bufferToReturn.data, _buffer.data + correctOffset, sizeof(float) * size); ++ memcpy(bufferToReturn.data, _buffer.data + correctOffset, ++ sizeof(float) * size); + } else { +- // If no; elements to be copied warps around to the beginning of the ring buffer. 
+- // Copy the chunk starting at ringBuffer[nextIndex + offset : size] to +- // beginning of the result array. ++ // If no; elements to be copied warps around to the beginning of the ring ++ // buffer. Copy the chunk starting at ringBuffer[nextIndex + offset : size] ++ // to beginning of the result array. + NSInteger endChunkSize = _buffer.size - correctOffset; +- memcpy(bufferToReturn.data, _buffer.data + correctOffset, sizeof(float) * endChunkSize); ++ memcpy(bufferToReturn.data, _buffer.data + correctOffset, ++ sizeof(float) * endChunkSize); + +- // Next copy the chunk starting at ringBuffer[0 : size - endChunkSize] to the result array. ++ // Next copy the chunk starting at ringBuffer[0 : size - endChunkSize] to ++ // the result array. + NSInteger firstChunkSize = size - endChunkSize; +- memcpy(bufferToReturn.data + endChunkSize, _buffer.data, sizeof(float) * firstChunkSize); ++ memcpy(bufferToReturn.data + endChunkSize, _buffer.data, ++ sizeof(float) * firstChunkSize); + } + + return bufferToReturn; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h index a117bd7b3c4c3..5058f7c9a5a7b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h @@ -8740,7 +9596,7 @@ NS_ASSUME_NONNULL_END diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h -index cdcddabe7323a..0f92dd1005631 100644 +index 330132f4ba138..7ab7e7240791e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h @@ -19,10 +19,10 @@ NS_ASSUME_NONNULL_BEGIN @@ -8756,7 +9612,7 @@ + * 
@discussion This property hould be greater than 0 or equal to -1. Setting it + * to -1 has the effect to let TFLite runtime set the value. */ - @property(nonatomic, assign) int numThreads; + @property(nonatomic) int numThreads; @@ -35,7 +35,7 @@ NS_SWIFT_NAME(ComputeSettings) @interface TFLComputeSettings : NSObject <NSCopying> @@ -8824,12 +9680,12 @@ NS_ASSUME_NONNULL_END diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m -index 12db7c866f1eb..7a49281cea9fb 100644 +index 7d49c36aa48c9..4139525500a59 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m -@@ -21,8 +21,8 @@ - - TFLCategory *category = [[TFLCategory alloc] init]; +@@ -19,8 +19,8 @@ + + (TFLCategory *)categoryWithCCategory:(TfLiteCategory *)cCategory { + if (cCategory == nil) return nil; - NSString *displayName; - NSString *label; @@ -8838,7 +9694,7 @@ if (cCategory->display_name != nil) { displayName = [NSString stringWithCString:cCategory->display_name -@@ -30,7 +30,8 @@ +@@ -28,7 +28,8 @@ } if (cCategory->label != nil) { @@ -8849,41 +9705,49 @@ return [[TFLCategory alloc] initWithIndex:(NSInteger)cCategory->index diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h -index 17abb9888cc18..b5b19af5a91a8 100644 +index 91060ef4f1840..5c521f2239ab7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h -@@ -19,24 +19,24 @@ NS_ASSUME_NONNULL_BEGIN - /** 
Encapsulates information about a class in the classification results. */ +@@ -20,24 +20,25 @@ NS_ASSUME_NONNULL_BEGIN + NS_SWIFT_NAME(ClassificationCategory) @interface TFLCategory : NSObject -/** Index of the class in the corresponding label map, usually packed in the TFLite Model - * Metadata. */ +/** Index of the class in the corresponding label map, usually packed in the + * TFLite Model Metadata. */ - @property(nonatomic, assign, readonly) NSInteger index; + @property(nonatomic, readonly) NSInteger index; /** Confidence score for this class . */ - @property(nonatomic, assign, readonly) float score; + @property(nonatomic, readonly) float score; /** Class name of the class. */ --@property(nonatomic, copy, readonly, nullable) NSString *label; -+@property(nonatomic, copy, readonly, nullable) NSString* label; +-@property(nonatomic, readonly, nullable) NSString *label; ++@property(nonatomic, readonly, nullable) NSString* label; /** Display name of the class. */ --@property(nonatomic, copy, readonly, nullable) NSString *displayName; -+@property(nonatomic, copy, readonly, nullable) NSString* displayName; +-@property(nonatomic, readonly, nullable) NSString *displayName; ++@property(nonatomic, readonly, nullable) NSString* displayName; /** - * Initializes TFLCategory. +- * Initializes a new `TFLCategory` with the given index, score, label and display name. ++ * Initializes a new `TFLCategory` with the given index, score, label and ++ * display name. * - * @param index Index of the class in the corresponding label map, usually packed in the TFLite - * Model Metadata. + * @param index Index of the class in the corresponding label map, usually + * packed in the TFLite Model Metadata. * - * @param score Confidence score for this class . + * @param score Confidence score for this class. * -@@ -49,8 +49,8 @@ NS_ASSUME_NONNULL_BEGIN +@@ -45,12 +46,13 @@ NS_SWIFT_NAME(ClassificationCategory) + * + * @param displayName Display name of the class. 
+ * +- * @return An instance of `TFLCategory` initialized with the given index, score, label and display name. ++ * @return An instance of `TFLCategory` initialized with the given index, score, ++ * label and display name. */ - (instancetype)initWithIndex:(NSInteger)index score:(float)score @@ -8892,7 +9756,7 @@ + label:(nullable NSString*)label + displayName:(nullable NSString*)displayName; - @end + - (instancetype)init NS_UNAVAILABLE; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.m index b72c3b55fdaf1..603c5a27c9673 100644 @@ -8910,7 +9774,7 @@ if (self) { _index = index; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h -index 7627aeaa4c394..33ecd084382fc 100644 +index b12c118e89021..152aa33dbdb59 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h @@ -18,11 +18,11 @@ @@ -8922,14 +9786,14 @@ +- (BOOL)copyToCOptions:(TfLiteClassificationOptions*)cClassificationOptions + error:(NSError**)error; - - (void)deleteCStringArraysOfClassificationOptions: + - (void)deleteAllocatedMemoryOfClassificationOptions: - (TfLiteClassificationOptions *)cClassificationOptions; + (TfLiteClassificationOptions*)cClassificationOptions; @end NS_ASSUME_NONNULL_END diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m -index 
6307bfe67731b..1d554caac4974 100644 +index 84e8fa5e234fb..767e5e4d577a3 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m @@ -20,21 +20,28 @@ @@ -8989,11 +9853,20 @@ return NO; } } +@@ -93,7 +102,7 @@ + } + + - (void)deleteAllocatedMemoryOfClassificationOptions: +- (TfLiteClassificationOptions *)cClassificationOptions { ++ (TfLiteClassificationOptions*)cClassificationOptions { + if (self.labelAllowList) { + [TFLClassificationOptions deleteCStringsArray:cClassificationOptions->label_allowlist.list + count:cClassificationOptions->label_allowlist.length]; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h -index 8282f9a273718..8e520686edc4c 100644 +index 41b69bec8a7d8..ce3f5d6580913 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h -@@ -22,13 +22,14 @@ NS_ASSUME_NONNULL_BEGIN +@@ -23,13 +23,14 @@ NS_SWIFT_NAME(ClassificationOptions) @interface TFLClassificationOptions : NSObject <NSCopying> /** If set, all classes in this list will be filtered out from the results . */ @@ -9011,18 +9884,18 @@ +@property(nonatomic, copy) NSString* displayNamesLocale; /** Results with score threshold greater than this value are returned . 
*/ - @property(nonatomic, assign) float scoreThreshold; + @property(nonatomic) float scoreThreshold; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h -index 829d56def82ca..9c8b92c362806 100644 +index 7ef58fc5b76ce..351e87db729c6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h @@ -20,17 +20,18 @@ NS_ASSUME_NONNULL_BEGIN @interface TFLClassificationResult (Helpers) /** -- * Creates and retrurns a TFLClassificationResult from a TfLiteClassificationResult returned by +- * Creates and returns a TFLClassificationResult from a TfLiteClassificationResult returned by - * TFLite Task C Library Classification tasks. -+ * Creates and retrurns a TFLClassificationResult from a ++ * Creates and returns a TFLClassificationResult from a + * TfLiteClassificationResult returned by TFLite Task C Library Classification + * tasks. 
* @@ -9044,10 +9917,17 @@ NS_ASSUME_NONNULL_END diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m -index ad604daf5daeb..1083e60d66a08 100644 +index c8744a3bf99c6..52e92852d88a9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m -@@ -24,17 +24,18 @@ +@@ -19,30 +19,34 @@ + + + (TFLClassificationResult *)classificationResultWithCResult: + (TfLiteClassificationResult *)cClassificationResult { +- if (!cClassificationResult) return nil; ++ if (!cClassificationResult) ++ return nil; + NSMutableArray *classificationHeads = [[NSMutableArray alloc] init]; for (int i = 0; i < cClassificationResult->size; i++) { TfLiteClassifications cClassifications = cClassificationResult->classifications[i]; @@ -9057,10 +9937,24 @@ TfLiteCategory cCategory = cClassifications.categories[j]; [categories addObject:[TFLCategory categoryWithCCategory:&cCategory]]; } -- TFLClassifications *classifications = [[TFLClassifications alloc] initWithHeadIndex:i + +- NSString *headName = nil; ++ NSString* headName = nil; + + if (cClassifications.head_name) { +- headName = [NSString stringWithCString:cClassifications.head_name encoding:NSUTF8StringEncoding]; ++ headName = [NSString stringWithCString:cClassifications.head_name ++ encoding:NSUTF8StringEncoding]; + } +- +- TFLClassifications *classifications = [[TFLClassifications alloc] initWithHeadIndex:cClassifications.head_index +- headName:headName - categories:categories]; -+ TFLClassifications* classifications = -+ [[TFLClassifications alloc] initWithHeadIndex:i categories:categories]; ++ ++ TFLClassifications* classifications = [[TFLClassifications alloc] ++ 
initWithHeadIndex:cClassifications.head_index ++ headName:headName ++ categories:categories]; [classificationHeads addObject:classifications]; } @@ -9071,16 +9965,17 @@ } @end diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h -index 35b62de68720b..80b1aafbc1742 100644 +index 72d5c85dec0d6..052b4f1daf710 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h -@@ -17,52 +17,57 @@ limitations under the License. +@@ -17,58 +17,66 @@ limitations under the License. NS_ASSUME_NONNULL_BEGIN -/** Encapsulates list of predicted classes (aka labels) for a given image classifier head. */ +/** Encapsulates list of predicted classes (aka labels) for a given image + * classifier head. */ + NS_SWIFT_NAME(Classifications) @interface TFLClassifications : NSObject /** @@ -9089,34 +9984,85 @@ + * The index of the image classifier head these classes refer to. This is useful + * for multi-head models. */ - @property(nonatomic, assign, readonly) NSInteger headIndex; + @property(nonatomic, readonly) NSInteger headIndex; + + /** The name of the classifier head, which is the corresponding tensor metadata +- * name. See https://github.com/tensorflow/tflite-support/blob/710e323265bfb71fdbdd72b3516e00cff15c0326/tensorflow_lite_support/metadata/metadata_schema.fbs#L545 +- * This will always be NULL for the `TFLClassifications` in the `TFLClassificationResult` returned by the follwing methods of `TFLImageClassifier`. ++ * name. 
See ++ * https://github.com/tensorflow/tflite-support/blob/710e323265bfb71fdbdd72b3516e00cff15c0326/tensorflow_lite_support/metadata/metadata_schema.fbs#L545 ++ * This will always be NULL for the `TFLClassifications` in the ++ * `TFLClassificationResult` returned by the follwing methods of ++ * `TFLImageClassifier`. + * 1. -[TFLImageClassifier classifyWithGMLImage:error:] + * 2. -[TFLImageClassifier classifyWithGMLImage:regionOfInterest:error:] + */ +-@property(nonatomic, readonly) NSString *headName; ++@property(nonatomic, readonly) NSString* headName; -/** The array of predicted classes, usually sorted by descending scores (e.g.from high to low - * probability). */ --@property(nonatomic, copy, readonly) NSArray<TFLCategory *> *categories; +-@property(nonatomic, readonly) NSArray<TFLCategory *> *categories; +/** The array of predicted classes, usually sorted by descending scores + * (e.g.from high to low probability). */ -+@property(nonatomic, copy, readonly) NSArray<TFLCategory*>* categories; ++@property(nonatomic, readonly) NSArray<TFLCategory*>* categories; /** - * Initializes TFLClassifications. +- * Initializes a new `TFLClassifications` with the given head index and array of categories. +- * head name is initialized to `nil`. ++ * Initializes a new `TFLClassifications` with the given head index and array of ++ * categories. head name is initialized to `nil`. * - * @param categories Array of TFLCategory objects encapsulating a list of +- * @param headIndex The index of the image classifier head these classes refer to. ++ * @param headIndex The index of the image classifier head these classes refer ++ * to. + * @param categories An array of `TFLCategory` objects encapsulating a list of - * predictions usually sorted by descending scores (e.g. from high to low probability). + * predictions usually sorted by descending scores (e.g. from high to low + * probability). 
- * @seealso TFLCategory * - * @return An instance of TFLClassifications initialized to - * the specified values. +- * @return An instance of `TFLClassifications` initialized with the given head index and +- * array of categories. ++ * @return An instance of `TFLClassifications` initialized with the given head ++ * index and array of categories. */ - (instancetype)initWithHeadIndex:(NSInteger)headIndex - categories:(NSArray<TFLCategory *> *)categories; +- ++ categories:(NSArray<TFLCategory*>*)categories; + + /** +- * Initializes a new `TFLClassifications` with the given head index, head name and array of categories. ++ * Initializes a new `TFLClassifications` with the given head index, head name ++ * and array of categories. + * +- * @param headIndex The index of the image classifier head these classes refer to. +- * @param headName The name of the classifier head, which is the corresponding tensor metadata +- * name. ++ * @param headIndex The index of the image classifier head these classes refer ++ * to. ++ * @param headName The name of the classifier head, which is the corresponding ++ * tensor metadata name. + * @param categories An array of `TFLCategory` objects encapsulating a list of +- * predictions usually sorted by descending scores (e.g. from high to low probability). ++ * predictions usually sorted by descending scores (e.g. from high to low ++ * probability). + * +- * @return An object of `TFLClassifications` initialized with the given head index, head name and +- * array of categories. ++ * @return An object of `TFLClassifications` initialized with the given head ++ * index, head name and array of categories. + */ + - (instancetype)initWithHeadIndex:(NSInteger)headIndex +- headName:(nullable NSString *)headName +- categories:(NSArray<TFLCategory *> *)categories; ++ headName:(nullable NSString*)headName + categories:(NSArray<TFLCategory*>*)categories; @end - /** Encapsulates results of any classification task. 
*/ +@@ -76,20 +84,23 @@ NS_SWIFT_NAME(Classifications) + NS_SWIFT_NAME(ClassificationResult) @interface TFLClassificationResult : NSObject -/** Array of TFLClassifications objects containing image classifier predictions per image classifier @@ -9124,22 +10070,22 @@ +/** Array of TFLClassifications objects containing image classifier predictions + * per image classifier head. */ --@property(nonatomic, copy, readonly) NSArray<TFLClassifications *> *classifications; -+@property(nonatomic, copy, readonly) -+ NSArray<TFLClassifications*>* classifications; +-@property(nonatomic, readonly) NSArray<TFLClassifications *> *classifications; ++@property(nonatomic, readonly) NSArray<TFLClassifications*>* classifications; /** - * Initializes TFLClassificationResult. +- * Initializes a new `TFLClassificationResult` with the given array of classifications. ++ * Initializes a new `TFLClassificationResult` with the given array of ++ * classifications. * -- * @param classifications Array of TFLClassifications objects containing image classifier +- * @param classifications An Aaray of `TFLClassifications` objects containing image classifier - * predictions per image classifier head. -+ * @param classifications Array of TFLClassifications objects containing image -+ * classifier predictions per image classifier head. - * @seealso TFLClassifications ++ * @param classifications An Aaray of `TFLClassifications` objects containing ++ * image classifier predictions per image classifier head. * -- * @return An instance of TFLClassificationResult initialized to the specified values. -+ * @return An instance of TFLClassificationResult initialized to the specified -+ * values. +- * @return An instance of 1TFLClassificationResult1 initialized with the given array of classifications. ++ * @return An instance of 1TFLClassificationResult1 initialized with the given ++ * array of classifications. 
*/ -- (instancetype)initWithClassifications:(NSArray<TFLClassifications *> *)classifications; +- (instancetype)initWithClassifications: @@ -9148,19 +10094,30 @@ @end diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m -index f118f40c064c2..b2ab012a4c899 100644 +index f56600cb94f3b..0ea238417c891 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m -@@ -17,7 +17,7 @@ limitations under the License. +@@ -17,9 +17,8 @@ limitations under the License. @implementation TFLClassifications - (instancetype)initWithHeadIndex:(NSInteger)headIndex +- headName:(nullable NSString *)headName - categories:(NSArray<TFLCategory *> *)categories { +- ++ headName:(nullable NSString*)headName + categories:(NSArray<TFLCategory*>*)categories { self = [super init]; if (self) { _headIndex = headIndex; -@@ -29,10 +29,11 @@ limitations under the License. +@@ -30,17 +29,18 @@ limitations under the License. 
+ } + + - (instancetype)initWithHeadIndex:(NSInteger)headIndex +- categories:(NSArray<TFLCategory *> *)categories { ++ categories:(NSArray<TFLCategory*>*)categories { + return [self initWithHeadIndex:headIndex headName:nil categories:categories]; + } + @end @implementation TFLClassificationResult { @@ -9175,10 +10132,10 @@ if (self) { _classifications = classifications; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h -index ff2c546a884cc..81efbcc1d8c57 100644 +index 7f6e8cae27f2c..81efbcc1d8c57 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.h -@@ -19,17 +19,17 @@ NS_ASSUME_NONNULL_BEGIN +@@ -19,16 +19,17 @@ NS_ASSUME_NONNULL_BEGIN @interface TFLDetectionResult (Helpers) /** @@ -9195,24 +10152,57 @@ + * @return Detection Result of type TFLDetectionResult to be returned by + * inference methods of the iOS TF Lite Task Object Detection task. 
*/ --+ (TFLDetectionResult *)detectionResultWithCResult: -- (TfLiteDetectionResult *)cDetectionResult; +-+ (TFLDetectionResult *)detectionResultWithCResult:(TfLiteDetectionResult *)cDetectionResult; ++ (TFLDetectionResult*)detectionResultWithCResult: + (TfLiteDetectionResult*)cDetectionResult; @end NS_ASSUME_NONNULL_END +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m +index 405bddf117cdd..3ae292cb0ef3b 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m +@@ -17,8 +17,10 @@ + + @implementation TFLDetectionResult (Helpers) + +-+ (TFLDetectionResult *)detectionResultWithCResult:(TfLiteDetectionResult *)cDetectionResult { +- if (!cDetectionResult) return nil; +++ (TFLDetectionResult*)detectionResultWithCResult: ++ (TfLiteDetectionResult*)cDetectionResult { ++ if (!cDetectionResult) ++ return nil; + + NSMutableArray *detections = [[NSMutableArray alloc] init]; + for (int i = 0; i < cDetectionResult->size; i++) { +@@ -30,10 +32,11 @@ + TFLCategory *resultCategory = [TFLCategory categoryWithCCategory:&cCategory]; + [categories addObject:resultCategory]; + } +- TFLDetection *detection = [[TFLDetection alloc] +- initWithBoundingBox:CGRectMake( +- cDetection.bounding_box.origin_x, cDetection.bounding_box.origin_y, +- cDetection.bounding_box.width, cDetection.bounding_box.height) ++ TFLDetection* detection = [[TFLDetection alloc] ++ initWithBoundingBox:CGRectMake(cDetection.bounding_box.origin_x, ++ cDetection.bounding_box.origin_y, ++ cDetection.bounding_box.width, ++ cDetection.bounding_box.height) + categories:categories]; + [detections addObject:detection]; + } diff --git 
a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h -index d5b97506ec508..4d7fc0c0503fa 100644 +index 0c64aa98b6089..00cc75bbc161e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h -@@ -19,25 +19,26 @@ limitations under the License. +@@ -19,31 +19,35 @@ limitations under the License. NS_ASSUME_NONNULL_BEGIN -/** Encapsulates list of predicted classes (aka labels) and bounding box for a detected object. */ +/** Encapsulates list of predicted classes (aka labels) and bounding box for a + * detected object. */ + NS_SWIFT_NAME(Detection) @interface TFLDetection : NSObject /** @@ -9221,25 +10211,81 @@ + * The index of the image classifier head these classes refer to. This is useful + * for multi-head models. */ - @property(nonatomic, assign) CGRect boundingBox; + @property(nonatomic, readonly) CGRect boundingBox; -/** The array of predicted classes, usually sorted by descending scores (e.g.from high to low - * probability). */ --@property(nonatomic, copy) NSArray<TFLCategory *> *categories; +-@property(nonatomic, readonly) NSArray<TFLCategory *> *categories; +/** The array of predicted classes, usually sorted by descending scores + * (e.g.from high to low probability). */ -+@property(nonatomic, copy) NSArray<TFLCategory*>* categories; ++@property(nonatomic, readonly) NSArray<TFLCategory*>* categories; - @end + /** +- * Initializes an object of `TFLDetection` with the given bounding box and array of categories. ++ * Initializes an object of `TFLDetection` with the given bounding box and array ++ * of categories. + * +- * @param boundingBox CGRect specifying the bounds of the object represented by this detection. 
+- * @param categories Array of predicted classes, usually sorted by descending scores (e.g.from high +- * to low probability). ++ * @param boundingBox CGRect specifying the bounds of the object represented by ++ * this detection. ++ * @param categories Array of predicted classes, usually sorted by descending ++ * scores (e.g.from high to low probability). + * +- * @return An instance of `TFLDetection` initialized with the given bounding box and array of categories. ++ * @return An instance of `TFLDetection` initialized with the given bounding box ++ * and array of categories. + */ + - (instancetype)initWithBoundingBox:(CGRect)boundingBox +- categories:(NSArray<TFLCategory *> *)categories; ++ categories:(NSArray<TFLCategory*>*)categories; - /** Encapsulates results of any object detection task. */ + - (instancetype)init NS_UNAVAILABLE; + +@@ -55,16 +59,17 @@ NS_SWIFT_NAME(Detection) + NS_SWIFT_NAME(DetectionResult) @interface TFLDetectionResult : NSObject --@property(nonatomic, copy) NSArray<TFLDetection *> *detections; -+@property(nonatomic, copy) NSArray<TFLDetection*>* detections; +-@property(nonatomic, readonly) NSArray<TFLDetection *> *detections; ++@property(nonatomic, readonly) NSArray<TFLDetection*>* detections; - @end + /** + * Initializes a new `TFLDetectionResult` with the given array of detections. + * + * @param detections Array of detected objects of type TFLDetection. + * +- * @return An instance of `TFLDetectionResult` initialized with the given array of detections. ++ * @return An instance of `TFLDetectionResult` initialized with the given array ++ * of detections. 
+ */ +-- (instancetype)initWithDetections:(NSArray<TFLDetection *> *)detections; ++- (instancetype)initWithDetections:(NSArray<TFLDetection*>*)detections; + - (instancetype)init NS_UNAVAILABLE; + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m +index 280767e6a353a..14cec3bca3d08 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m +@@ -17,7 +17,7 @@ limitations under the License. + @implementation TFLDetection + + - (instancetype)initWithBoundingBox:(CGRect)boundingBox +- categories:(NSArray<TFLCategory *> *)categories { ++ categories:(NSArray<TFLCategory*>*)categories { + self = [super init]; + if (self) { + _boundingBox = boundingBox; +@@ -30,7 +30,7 @@ limitations under the License. 
+ + @implementation TFLDetectionResult + +-- (instancetype)initWithDetections:(NSArray<TFLDetection *> *)detections { ++- (instancetype)initWithDetections:(NSArray<TFLDetection*>*)detections { + self = [super init]; + if (self) { + _detections = detections; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h index c979fda53c70b..0a85efe2877bb 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h @@ -9256,19 +10302,19 @@ NS_ASSUME_NONNULL_END diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m -index 8249fc3c95e93..b531e78df2d82 100644 +index f2ea957ca3010..2a897f0ba3614 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m -@@ -16,25 +16,27 @@ +@@ -16,29 +16,31 @@ @implementation TFLSegmentationResult (Helpers) -+ (TFLSegmentationResult *)segmentationResultWithCResult: - (TfLiteSegmentationResult *)cSegmentationResult { -- if (cSegmentationResult == nil) return nil; +- if (!cSegmentationResult) return nil; ++ (TFLSegmentationResult*)segmentationResultWithCResult: + (TfLiteSegmentationResult*)cSegmentationResult { -+ if (cSegmentationResult == nil) ++ if (!cSegmentationResult) + return nil; - NSMutableArray *segmentations = [[NSMutableArray alloc] init]; @@ -9280,28 +10326,32 @@ for (int j = 0; j < cSegmentation.colored_labels_size; j++) { TfLiteColoredLabel cColoredLabel = cSegmentation.colored_labels[j]; -- TFLColoredLabel 
*coloredLabel = [[TFLColoredLabel alloc] init]; -+ TFLColoredLabel* coloredLabel = [[TFLColoredLabel alloc] init]; - coloredLabel.r = (NSUInteger)cColoredLabel.r; - coloredLabel.g = (NSUInteger)cColoredLabel.g; - coloredLabel.b = (NSUInteger)cColoredLabel.b; - - if (cColoredLabel.display_name != nil) { -- coloredLabel.displayName = [NSString stringWithCString:cColoredLabel.display_name -- encoding:NSUTF8StringEncoding]; -+ coloredLabel.displayName = -+ [NSString stringWithCString:cColoredLabel.display_name -+ encoding:NSUTF8StringEncoding]; +- NSString *displayName; ++ NSString* displayName; + if (cColoredLabel.display_name) { + displayName = [NSString stringWithCString:cColoredLabel.display_name + encoding:NSUTF8StringEncoding]; } - if (cColoredLabel.label != nil) { -@@ -45,16 +47,16 @@ +- NSString *label; ++ NSString* label; + if (cColoredLabel.label) { +- label = [NSString stringWithCString:cColoredLabel.label encoding:NSUTF8StringEncoding]; ++ label = [NSString stringWithCString:cColoredLabel.label ++ encoding:NSUTF8StringEncoding]; + } + +- TFLColoredLabel *coloredLabel = ++ TFLColoredLabel* coloredLabel = + [[TFLColoredLabel alloc] initWithRed:(NSUInteger)cColoredLabel.r + green:(NSUInteger)cColoredLabel.g + blue:(NSUInteger)cColoredLabel.b +@@ -47,27 +49,29 @@ [coloredLabels addObject:coloredLabel]; } -- TFLSegmentation *segmentation = [[TFLSegmentation alloc] init]; -+ TFLSegmentation* segmentation = [[TFLSegmentation alloc] init]; - segmentation.coloredLabels = coloredLabels; +- TFLSegmentation *segmentation; ++ TFLSegmentation* segmentation; if (cSegmentation.confidence_masks) { - NSMutableArray *confidenceMasks = [[NSMutableArray alloc] init]; @@ -9317,22 +10367,31 @@ + mask:cSegmentation.confidence_masks[i]]; [confidenceMasks addObject:confidenceMask]; } - segmentation.confidenceMasks = confidenceMasks; -@@ -69,7 +71,8 @@ - [segmentations addObject:segmentation]; - } +- segmentation = [[TFLSegmentation alloc] initWithConfidenceMasks:confidenceMasks +- 
coloredLabels:coloredLabels]; ++ segmentation = ++ [[TFLSegmentation alloc] initWithConfidenceMasks:confidenceMasks ++ coloredLabels:coloredLabels]; -- TFLSegmentationResult *segmentationResult = [[TFLSegmentationResult alloc] init]; -+ TFLSegmentationResult* segmentationResult = -+ [[TFLSegmentationResult alloc] init]; - segmentationResult.segmentations = segmentations; - return segmentationResult; - } + } else if (cSegmentation.category_mask) { +- TFLCategoryMask *categoryMask = ++ TFLCategoryMask* categoryMask = + [[TFLCategoryMask alloc] initWithWidth:(NSInteger)cSegmentation.width + height:(NSInteger)cSegmentation.height + mask:cSegmentation.category_mask]; +- segmentation = [[TFLSegmentation alloc] initWithCategoryMask:categoryMask +- coloredLabels:coloredLabels]; ++ segmentation = ++ [[TFLSegmentation alloc] initWithCategoryMask:categoryMask ++ coloredLabels:coloredLabels]; + } + + [segmentations addObject:segmentation]; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h -index 463ef3bb57bad..49abd3f25aa41 100644 +index 1307e26294dd4..3aca4567ebe2e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h -@@ -22,7 +22,7 @@ NS_ASSUME_NONNULL_BEGIN +@@ -23,7 +23,7 @@ NS_SWIFT_NAME(ConfidenceMask) /** * Confidence masks of size `width` x `height` for any one class. */ @@ -9341,16 +10400,16 @@ /** * The width of the mask. 
This is an intrinsic parameter of the model being -@@ -41,7 +41,7 @@ NS_ASSUME_NONNULL_BEGIN +@@ -42,7 +42,7 @@ NS_SWIFT_NAME(ConfidenceMask) */ - (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height - mask:(float * _Nullable)mask; + mask:(float* _Nullable)mask; - @end + - (instancetype)init NS_UNAVAILABLE; -@@ -53,7 +53,7 @@ NS_ASSUME_NONNULL_BEGIN +@@ -59,7 +59,7 @@ NS_SWIFT_NAME(CategoryMask) * The value of each pixel in this mask represents the class to which the * pixel belongs. */ @@ -9359,78 +10418,163 @@ /** * The width of the mask. This is an intrinsic parameter of the model being -@@ -72,7 +72,7 @@ NS_ASSUME_NONNULL_BEGIN +@@ -80,15 +80,15 @@ NS_SWIFT_NAME(CategoryMask) + * + * @param width Width of the mask. + * @param height Height of the mask. +- * @param mask Flattened 2D-array of size `width` x `height`, in row major order. +- * The value of each pixel in this mask represents the class to which the ++ * @param mask Flattened 2D-array of size `width` x `height`, in row major ++ * order. The value of each pixel in this mask represents the class to which the + * pixel belongs. + * + * @return An instance of TFLCategoryMask initialized to the specified values. */ - (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height - mask:(UInt8 * _Nullable)mask; + mask:(UInt8* _Nullable)mask; - @end + - (instancetype)init NS_UNAVAILABLE; -@@ -87,13 +87,13 @@ NS_ASSUME_NONNULL_BEGIN - /** The class name, as provided in the label map packed in the TFLite Model +@@ -107,17 +107,18 @@ NS_SWIFT_NAME(ColoredLabel) + * The class name, as provided in the label map packed in the TFLite Model * Metadata. */ --@property(nonatomic, copy) NSString *label; -+@property(nonatomic, copy) NSString* label; +-@property(nonatomic, readonly) NSString *label; ++@property(nonatomic, readonly) NSString* label; - /** The display name, as provided in the label map (if available) packed in - * the TFLite Model Metadata. 
See `display_names_locale` field in - * ImageSegmenterOptions. + /** + * The display name, as provided in the label map (if available) packed in + * the TFLite Model Metadata. See displayNamesLocale in + * TFLClassificationOptions. */ --@property(nonatomic, copy) NSString *displayName; -+@property(nonatomic, copy) NSString* displayName; +-@property(nonatomic, readonly) NSString *displayName; ++@property(nonatomic, readonly) NSString* displayName; - @end + /** +- * Initializes a new `TFLColoredLabel` with red, gree, blue color components, label and display name. ++ * Initializes a new `TFLColoredLabel` with red, gree, blue color components, ++ * label and display name. + * + * @param r Red component of the RGB color components. + * @param g Green component of the RGB color components. +@@ -125,13 +126,14 @@ NS_SWIFT_NAME(ColoredLabel) + * @param label Class name. + * @param displayName Display name. + * +- * @return An instance of TFLColoredLabel initialized with red, gree, blue color components, label and display name. ++ * @return An instance of TFLColoredLabel initialized with red, gree, blue color ++ * components, label and display name. + */ + - (instancetype)initWithRed:(NSUInteger)r + green:(NSUInteger)g + blue:(NSUInteger)b +- label:(NSString *)label +- displayName:(NSString *)displayName; ++ label:(NSString*)label ++ displayName:(NSString*)displayName; -@@ -107,14 +107,15 @@ NS_ASSUME_NONNULL_BEGIN + - (instancetype)init NS_UNAVAILABLE; + +@@ -150,7 +152,8 @@ NS_SWIFT_NAME(Segmentation) * this particular class. * This property is mutually exclusive with `categoryMask`. */ --@property(nonatomic, strong, nullable) NSArray<TFLConfidenceMask *> *confidenceMasks; -+@property(nonatomic, strong, nullable) +-@property(nonatomic, nullable, readonly) NSArray<TFLConfidenceMask *> *confidenceMasks; ++@property(nonatomic, nullable, readonly) + NSArray<TFLConfidenceMask*>* confidenceMasks; - /** Holds the category mask. 
- * The value of each pixel in this mask represents the class to which the + /** + * Holds the category mask. +@@ -158,7 +161,7 @@ NS_SWIFT_NAME(Segmentation) * pixel belongs. * This property is mutually exclusive with `confidenceMasks`. */ --@property(nonatomic, strong, nullable) TFLCategoryMask *categoryMask; -+@property(nonatomic, strong, nullable) TFLCategoryMask* categoryMask; +-@property(nonatomic, nullable, readonly) TFLCategoryMask *categoryMask; ++@property(nonatomic, nullable, readonly) TFLCategoryMask* categoryMask; /** * The list of colored labels for all the supported categories (classes). -@@ -123,7 +124,7 @@ NS_ASSUME_NONNULL_BEGIN +@@ -167,33 +170,38 @@ NS_SWIFT_NAME(Segmentation) * `colored_labels[i]`, `confidence_masks` indices, i.e. `confidence_masks[i]` * is associated with `colored_labels[i]`. */ --@property(nonatomic, strong) NSArray<TFLColoredLabel *> *coloredLabels; -+@property(nonatomic, strong) NSArray<TFLColoredLabel*>* coloredLabels; +-@property(nonatomic, readonly) NSArray<TFLColoredLabel *> *coloredLabels; ++@property(nonatomic, readonly) NSArray<TFLColoredLabel*>* coloredLabels; - @end + + (instancetype)new NS_UNAVAILABLE; -@@ -136,7 +137,7 @@ NS_ASSUME_NONNULL_BEGIN + /** +- * Initializes a new `TFLSegmentation` with an array of confidence masks and an array of colored labels. +- * `categoryMask` is initialized to `nil` as it is mutually exclusive with `confidenceMasks`. ++ * Initializes a new `TFLSegmentation` with an array of confidence masks and an ++ * array of colored labels. `categoryMask` is initialized to `nil` as it is ++ * mutually exclusive with `confidenceMasks`. + * + * @param confidenceMasks An array of `TFLConfidenceMask` objects. + * @param coloredLabels An array of `TFLColoredLabel` objects. + * +- * @return An instance of `TFLSegmentation` initialized with an array of confidence masks and an array of colored labels. 
++ * @return An instance of `TFLSegmentation` initialized with an array of ++ * confidence masks and an array of colored labels. + */ +-- (instancetype)initWithConfidenceMasks:(NSArray<TFLConfidenceMask *> *)confidenceMasks +- coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels; ++- (instancetype) ++ initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks ++ coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels; + + /** +- * Initializes a new `TFLSegmentation` with a category mask and array of colored labels. +- * `confidenceMasks` is initialized to `nil` as it is mutually exclusive with `categoryMask`. ++ * Initializes a new `TFLSegmentation` with a category mask and array of colored ++ * labels. `confidenceMasks` is initialized to `nil` as it is mutually exclusive ++ * with `categoryMask`. + * + * @param categoryMask A `TFLCategoryMask` object. + * @param coloredLabels An array of `TFLColoredLabel` objects. + * +- * @return An instance of `TFLSegmentation` initialized with a category mask and array of colored labels. ++ * @return An instance of `TFLSegmentation` initialized with a category mask and ++ * array of colored labels. + */ +-- (instancetype)initWithCategoryMask:(TFLCategoryMask *)categoryMask +- coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels; ++- (instancetype)initWithCategoryMask:(TFLCategoryMask*)categoryMask ++ coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels; + + - (instancetype)init NS_UNAVAILABLE; + +@@ -209,7 +217,7 @@ NS_SWIFT_NAME(SegmentationResult) * e.g. instance segmentation models, which may return one segmentation per * object. 
*/ --@property(nonatomic, strong) NSArray<TFLSegmentation *> *segmentations; -+@property(nonatomic, strong) NSArray<TFLSegmentation*>* segmentations; +-@property(nonatomic, readonly) NSArray<TFLSegmentation *> *segmentations; ++@property(nonatomic, readonly) NSArray<TFLSegmentation*>* segmentations; - @end + + (instancetype)new NS_UNAVAILABLE; + +@@ -218,9 +226,10 @@ NS_SWIFT_NAME(SegmentationResult) + * + * @param segmentations An array of `TFLSegmentation` objects. + * +- * @return An instance of `TFLSegmentationResult` initialized with an array of segmentations. ++ * @return An instance of `TFLSegmentationResult` initialized with an array of ++ * segmentations. + */ +-- (instancetype)initWithSegmentations:(NSArray<TFLSegmentation *> *)segmentations; ++- (instancetype)initWithSegmentations:(NSArray<TFLSegmentation*>*)segmentations; + + - (instancetype)init NS_UNAVAILABLE; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m -index caa0162be8d26..bf0e75c8e1099 100644 +index 33defd1139509..45b5510525fdc 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m -@@ -17,10 +17,12 @@ - @implementation TFLCategoryMask { - NSInteger _width; - NSInteger _height; -- UInt8 *_mask; -+ UInt8* _mask; - } +@@ -17,13 +17,16 @@ + + @implementation TFLCategoryMask -- (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height mask:(UInt8 *)mask { +- (instancetype)initWithWidth:(NSInteger)width @@ -9439,7 +10583,15 @@ self = [super init]; if (self) { _width = width; -@@ -33,7 +35,7 @@ + _height = height; + if (mask != NULL) { +- _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(UInt8) error:nil]; ++ _mask = 
[TFLCommonUtils mallocWithSize:width * height * sizeof(UInt8) ++ error:nil]; + if (_mask) { + memcpy(_mask, mask, width * height * sizeof(UInt8)); + } +@@ -32,7 +35,7 @@ return self; } @@ -9448,13 +10600,9 @@ return [[TFLCategoryMask alloc] initWithWidth:self.width height:self.height mask:self.mask]; -@@ -48,10 +50,12 @@ - @implementation TFLConfidenceMask { - NSInteger _width; - NSInteger _height; -- float *_mask; -+ float* _mask; - } +@@ -46,13 +49,16 @@ + + @implementation TFLConfidenceMask -- (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height mask:(float *)mask { +- (instancetype)initWithWidth:(NSInteger)width @@ -9463,7 +10611,15 @@ self = [super init]; if (self) { _width = width; -@@ -64,7 +68,7 @@ + _height = height; + if (mask != NULL) { +- _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(float) error:nil]; ++ _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(float) ++ error:nil]; + if (_mask) { + memcpy(_mask, mask, width * height * sizeof(float)); + } +@@ -61,7 +67,7 @@ return self; } @@ -9472,6 +10628,61 @@ return [[TFLConfidenceMask alloc] initWithWidth:self.width height:self.height mask:self.mask]; +@@ -78,8 +84,8 @@ + - (instancetype)initWithRed:(NSUInteger)r + green:(NSUInteger)g + blue:(NSUInteger)b +- label:(NSString *)label +- displayName:(NSString *)displayName { ++ label:(NSString*)label ++ displayName:(NSString*)displayName { + self = [super init]; + if (self) { + _r = r; +@@ -95,21 +101,25 @@ + + @implementation TFLSegmentation + +-- (instancetype)initWithConfidenceMasks:(NSArray<TFLConfidenceMask *> *)confidenceMasks +- coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels { ++- (instancetype) ++ initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks ++ coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels { + return [self initWithConfidenceMasks:confidenceMasks + categoryMask:nil + coloredLabels:coloredLabels]; + } + +-- (instancetype)initWithCategoryMask:(TFLCategoryMask 
*)categoryMask +- coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels { +- return [self initWithConfidenceMasks:nil categoryMask:categoryMask coloredLabels:coloredLabels]; ++- (instancetype)initWithCategoryMask:(TFLCategoryMask*)categoryMask ++ coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels { ++ return [self initWithConfidenceMasks:nil ++ categoryMask:categoryMask ++ coloredLabels:coloredLabels]; + } + +-- (instancetype)initWithConfidenceMasks:(NSArray<TFLConfidenceMask *> *)confidenceMasks +- categoryMask:(TFLCategoryMask *)categoryMask +- coloredLabels:(NSArray<TFLColoredLabel *> *)coloredLabels { ++- (instancetype) ++ initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks ++ categoryMask:(TFLCategoryMask*)categoryMask ++ coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels { + self = [super init]; + if (self) { + _confidenceMasks = confidenceMasks; +@@ -123,7 +133,8 @@ + + @implementation TFLSegmentationResult + +-- (instancetype)initWithSegmentations:(NSArray<TFLSegmentation *> *)segmentations { ++- (instancetype)initWithSegmentations: ++ (NSArray<TFLSegmentation*>*)segmentations { + self = [super init]; + if (self) { + _segmentations = segmentations; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h index 99de5ad04febf..ac81a15ac11c6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h @@ -9637,100 +10848,167 @@ * @param question Question to ask. 
* diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h -index 201fd9a40c54a..5befd570b9749 100644 +index f228034147c40..7e38abe002623 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h -@@ -30,27 +30,29 @@ NS_ASSUME_NONNULL_BEGIN - * Base options that is used for creation of any type of task. - * @seealso TFLBaseOptions +@@ -31,29 +31,32 @@ NS_SWIFT_NAME(ImageClassifierOptions) + * Base options that are used for creation of any type of task. + * @discussion Please see `TFLBaseOptions` for more details. */ -@property(nonatomic, copy) TFLBaseOptions *baseOptions; +@property(nonatomic, copy) TFLBaseOptions* baseOptions; /** * Options that configure the display and filtering of results. - * @seealso TFLClassificationOptions + * @discussion Please see `TFLClassificationOptions` for more details. */ -@property(nonatomic, copy) TFLClassificationOptions *classificationOptions; +@property(nonatomic, copy) TFLClassificationOptions* classificationOptions; /** -- * Initializes TFLImageClassifierOptions with the model path set to the specified path to a model -- * file. -- * @description The external model file, must be a single standalone TFLite file. It could be packed +- * Initializes a new `TFLImageClassifierOptions` with the absolute path to the model file +- * stored locally on the device, set to the given the model path. ++ * Initializes a new `TFLImageClassifierOptions` with the absolute path to the ++ * model file stored locally on the device, set to the given the model path. + * +- * @discussion The external model file, must be a single standalone TFLite file. It could be packed - * with TFLite Model Metadata[1] and associated files if exist. 
Fail to provide the necessary - * metadata and associated files might result in errors. Check the [documentation] - * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. -+ * Initializes TFLImageClassifierOptions with the model path set to the -+ * specified path to a model file. -+ * @description The external model file, must be a single standalone TFLite -+ * file. It could be packed with TFLite Model Metadata[1] and associated files -+ * if exist. Fail to provide the necessary metadata and associated files might ++ * @discussion The external model file, must be a single standalone TFLite file. ++ * It could be packed with TFLite Model Metadata[1] and associated files if ++ * exist. Fail to provide the necessary metadata and associated files might + * result in errors. Check the [documentation] + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the + * specific requirement. * - * @param modelPath Path to a TFLite model file. - * @return An instance of TFLImageClassifierOptions set to the specified - * modelPath. - */ --- (nullable instancetype)initWithModelPath:(NSString *)modelPath; -+- (nullable instancetype)initWithModelPath:(NSString*)modelPath; - - - (instancetype)init NS_UNAVAILABLE; - -@@ -71,8 +73,9 @@ NS_ASSUME_NONNULL_BEGIN +- * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. ++ * @param modelPath An absolute path to a TensorFlow Lite model file stored ++ * locally on the device. * - * @return A TFLImageClassifier instance. + * @return An instance of `TFLImageClassifierOptions` initialized to the given + * model path. 
+ */ +-- (instancetype)initWithModelPath:(NSString *)modelPath; ++- (instancetype)initWithModelPath:(NSString*)modelPath; + + @end + +@@ -64,17 +67,19 @@ NS_SWIFT_NAME(ImageClassifier) + @interface TFLImageClassifier : NSObject + + /** +- * Creates a new instance of `TFLImageClassifier` from the given `TFLImageClassifierOptions`. ++ * Creates a new instance of `TFLImageClassifier` from the given ++ * `TFLImageClassifierOptions`. + * + * @param options The options to use for configuring the `TFLImageClassifier`. +- * @param error An optional error parameter populated when there is an error in initializing +- * the image classifier. ++ * @param error An optional error parameter populated when there is an error in ++ * initializing the image classifier. + * +- * @return A new instance of `TFLImageClassifier` with the given options. `nil` if there is an error +- * in initializing the image classifier. ++ * @return A new instance of `TFLImageClassifier` with the given options. `nil` ++ * if there is an error in initializing the image classifier. */ -+ (nullable instancetype)imageClassifierWithOptions:(TFLImageClassifierOptions *)options - error:(NSError **)error ++ (nullable instancetype)imageClassifierWithOptions: + (TFLImageClassifierOptions*)options + error:(NSError**)error - NS_SWIFT_NAME(imageClassifier(options:)); + NS_SWIFT_NAME(classifier(options:)); + + (instancetype)new NS_UNAVAILABLE; +@@ -82,46 +87,49 @@ NS_SWIFT_NAME(ImageClassifier) /** -@@ -92,8 +95,9 @@ NS_ASSUME_NONNULL_BEGIN - * @param image input to the model. - * @return An NSArray<NSArray<TFLClass *>*> * of classification results. + * Performs classification on the given GMLImage. + * +- * @discussion This method currently supports classification of only the following types of images: ++ * @discussion This method currently supports classification of only the ++ * following types of images: + * 1. RGB and RGBA images for `GMLImageSourceTypeImage`. + * 2. 
kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and +- * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup +- * camera and get the frames for inference, you must request for this format +- * from AVCaptureVideoDataOutput. Otherwise your classification +- * results will be wrong. ++ * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to ++ * setup camera and get the frames for inference, you must request for this ++ * format from AVCaptureVideoDataOutput. Otherwise your classification results ++ * will be wrong. + * + * @param image An image to be classified, represented as a `GMLImage`. + * +- * @return A TFLClassificationResult with one set of results per image classifier head. `nil` if +- * there is an error encountered during classification. Please see `TFLClassificationResult` for +- * more details. ++ * @return A TFLClassificationResult with one set of results per image ++ * classifier head. `nil` if there is an error encountered during ++ * classification. Please see `TFLClassificationResult` for more details. */ -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image -- error:(NSError *_Nullable *)error +- error:(NSError **)error +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image -+ error:(NSError* _Nullable*) -+ error - NS_SWIFT_NAME(classify(gmlImage:)); ++ error:(NSError**)error + NS_SWIFT_NAME(classify(mlImage:)); /** -@@ -107,9 +111,10 @@ NS_ASSUME_NONNULL_BEGIN +- * Performs classification on the pixels within the specified region of interest of the given +- * `GMLImage`. ++ * Performs classification on the pixels within the specified region of interest ++ * of the given `GMLImage`. * - * @return An NSArray<NSArray<TFLClass *>*> * of classification results. 
+- * @discussion This method currently supports inference on only following type of images: ++ * @discussion This method currently supports inference on only following type ++ * of images: + * 1. RGB and RGBA images for `GMLImageSourceTypeImage`. + * 2. kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and +- * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup +- * camera and get the frames for inference, you must request for this format +- * from AVCaptureVideoDataOutput. Otherwise your classification +- * results will be wrong. ++ * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to ++ * setup camera and get the frames for inference, you must request for this ++ * format from AVCaptureVideoDataOutput. Otherwise your classification results ++ * will be wrong. + * + * @param image An image to be classified, represented as a `GMLImage`. +- * @param roi A CGRect specifying the region of interest within the given `GMLImage`, on which +- * classification should be performed. ++ * @param roi A CGRect specifying the region of interest within the given ++ * `GMLImage`, on which classification should be performed. + * +- * @return A TFLClassificationResult with one set of results per image classifier head. `nil` if +- * there is an error encountered during classification. ++ * @return A TFLClassificationResult with one set of results per image ++ * classifier head. `nil` if there is an error encountered during ++ * classification. 
*/ -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image - regionOfInterest:(CGRect)roi -- error:(NSError *_Nullable *)error +- error:(NSError **)error +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image + regionOfInterest:(CGRect)roi -+ error:(NSError* _Nullable*) -+ error - NS_SWIFT_NAME(classify(gmlImage:regionOfInterest:)); ++ error:(NSError**)error + NS_SWIFT_NAME(classify(mlImage:regionOfInterest:)); - (instancetype)init NS_UNAVAILABLE; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m -index f1b6c7a7990e4..9259a8dd2defc 100644 +index f8c09527bd902..79ad474054525 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m @@ -40,7 +40,7 @@ return self; } --- (nullable instancetype)initWithModelPath:(NSString *)modelPath { -+- (nullable instancetype)initWithModelPath:(NSString*)modelPath { +-- (instancetype)initWithModelPath:(NSString *)modelPath { ++- (instancetype)initWithModelPath:(NSString*)modelPath { self = [self init]; if (self) { self.baseOptions.modelFile.filePath = modelPath; -@@ -63,11 +63,13 @@ +@@ -63,40 +63,45 @@ return self; } @@ -9739,26 +11017,82 @@ ++ (nullable instancetype)imageClassifierWithOptions: + (TFLImageClassifierOptions*)options + error:(NSError**)error { - TfLiteImageClassifierOptions cOptions = TfLiteImageClassifierOptionsCreate(); -- if (! 
-- [options.classificationOptions copyToCOptions:&(cOptions.classification_options) error:error]) -+ if (![options.classificationOptions -+ copyToCOptions:&(cOptions.classification_options) -+ error:error]) - return nil; - - [options.baseOptions copyToCOptions:&(cOptions.base_options)]; -@@ -79,7 +81,8 @@ - [options.classificationOptions - deleteCStringArraysOfClassificationOptions:&(cOptions.classification_options)]; - -- if (!imageClassifier || ![TFLCommonUtils checkCError:createClassifierError toError:error]) { -+ if (!imageClassifier || ![TFLCommonUtils checkCError:createClassifierError -+ toError:error]) { - TfLiteSupportErrorDelete(createClassifierError); + if (!options) { +- [TFLCommonUtils createCustomError:error +- withCode:TFLSupportErrorCodeInvalidArgumentError +- description:@"TFLImageClassifierOptions argument cannot be nil."]; ++ [TFLCommonUtils ++ createCustomError:error ++ withCode:TFLSupportErrorCodeInvalidArgumentError ++ description:@"TFLImageClassifierOptions argument cannot be nil."]; return nil; } -@@ -104,7 +107,7 @@ + + TfLiteImageClassifierOptions cOptions = TfLiteImageClassifierOptionsCreate(); + +- if (![options.classificationOptions copyToCOptions:&(cOptions.classification_options) +- error:error]) { +- [options.classificationOptions +- deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)]; ++ if (![options.classificationOptions ++ copyToCOptions:&(cOptions.classification_options) ++ error:error]) { ++ [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions: ++ &(cOptions.classification_options)]; + return nil; + } + + [options.baseOptions copyToCOptions:&(cOptions.base_options)]; + +- TfLiteSupportError *cCreateClassifierError = NULL; +- TfLiteImageClassifier *cImageClassifier = ++ TfLiteSupportError* cCreateClassifierError = NULL; ++ TfLiteImageClassifier* cImageClassifier = + TfLiteImageClassifierFromOptions(&cOptions, &cCreateClassifierError); + +- [options.classificationOptions +- 
deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)]; ++ [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions: ++ &(cOptions.classification_options)]; + +- // Populate iOS error if TfliteSupportError is not null and afterwards delete it. ++ // Populate iOS error if TfliteSupportError is not null and afterwards delete ++ // it. + if (![TFLCommonUtils checkCError:cCreateClassifierError toError:error]) { + TfLiteSupportErrorDelete(cCreateClassifierError); + } + +- // Return nil if classifier evaluates to nil. If an error was generted by the C layer, it has +- // already been populated to an NSError and deleted before returning from the method. ++ // Return nil if classifier evaluates to nil. If an error was generted by the ++ // C layer, it has already been populated to an NSError and deleted before ++ // returning from the method. + if (!cImageClassifier) { + return nil; + } +@@ -104,16 +109,16 @@ + return [[TFLImageClassifier alloc] initWithImageClassifier:cImageClassifier]; + } + +-- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image +- error:(NSError **)error { ++- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image ++ error:(NSError**)error { + return [self classifyWithGMLImage:image + regionOfInterest:CGRectMake(0, 0, image.width, image.height) + error:error]; + } + +-- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image +- regionOfInterest:(CGRect)roi +- error:(NSError **)error { ++- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image ++ regionOfInterest:(CGRect)roi ++ error:(NSError**)error { + if (!image) { + [TFLCommonUtils createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError +@@ -121,7 +126,7 @@ return nil; } @@ -9767,141 +11101,163 @@ if (!cFrameBuffer) { return nil; -@@ -125,7 +128,8 @@ - free(cFrameBuffer); - cFrameBuffer = nil; +@@ -132,7 +137,7 @@ + .width = roi.size.width, + .height = 
roi.size.height}; -- if (!cClassificationResult || ![TFLCommonUtils checkCError:classifyError toError:error]) { -+ if (!cClassificationResult || ![TFLCommonUtils checkCError:classifyError -+ toError:error]) { +- TfLiteSupportError *classifyError = NULL; ++ TfLiteSupportError* classifyError = NULL; + TfLiteClassificationResult *cClassificationResult = TfLiteImageClassifierClassifyWithRoi( + _imageClassifier, cFrameBuffer, &boundingBox, &classifyError); + +@@ -147,8 +152,9 @@ TfLiteSupportErrorDelete(classifyError); + } + +- // Return nil if C result evaluates to nil. If an error was generted by the C layer, it has +- // already been populated to an NSError and deleted before returning from the method. ++ // Return nil if C result evaluates to nil. If an error was generted by the C ++ // layer, it has already been populated to an NSError and deleted before ++ // returning from the method. + if (!cClassificationResult) { return nil; } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h -index 48bf54e285930..eb0d565884a22 100644 +index 7b556dcd312e2..234e10d68b319 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h -@@ -20,10 +20,10 @@ +@@ -20,9 +20,10 @@ NS_ASSUME_NONNULL_BEGIN /** -- * Specifies the type of output segmentation mask to be returned as a result -- * of the image segmentation operation. This allows specifying the type of -- * post-processing to perform on the raw model results -- * -+ * Specifies the type of output segmentation mask to be returned as a result -+ * of the image segmentation operation. 
This allows specifying the type of -+ * post-processing to perform on the raw model results -+ * - * @seealso TfLiteSegmentationResult for more. +- * Specifies the type of the output segmentation mask to be returned as the result +- * of the image segmentation operation. This directs the `TFLImageSegmenter` to +- * choose the type of post-processing to be performed on the raw model results. ++ * Specifies the type of the output segmentation mask to be returned as the ++ * result of the image segmentation operation. This directs the ++ * `TFLImageSegmenter` to choose the type of post-processing to be performed on ++ * the raw model results. */ typedef NS_ENUM(NSUInteger, TFLOutputType) { -@@ -31,13 +31,13 @@ typedef NS_ENUM(NSUInteger, TFLOutputType) { - TFLUnspecifiedOutputType, - - /** -- * Gives a single output mask where each pixel represents the class which -+ * Gives a single output mask where each pixel represents the class which - * the pixel in the original image was predicted to belong to. - */ - TFLCategoryMaskOutputType, - - /** -- * Gives a list of output masks where, for each mask, each pixel represents -+ * Gives a list of output masks where, for each mask, each pixel represents - * the prediction confidence, usually in the [0, 1] range. - */ - TFLConfidenceMasksOutputType, -@@ -53,33 +53,34 @@ typedef NS_ENUM(NSUInteger, TFLOutputType) { + /** Unspecified output type. */ +@@ -52,7 +53,7 @@ NS_SWIFT_NAME(ImageSegmenterOptions) * Base options that is used for creation of any type of task. - * @seealso TFLBaseOptions + * @discussion Please see `TFLBaseOptions` for more details. */ -@property(nonatomic, copy) TFLBaseOptions *baseOptions; +@property(nonatomic, copy) TFLBaseOptions* baseOptions; /** -- * Specifies the type of output segmentation mask to be returned as a result -+ * Specifies the type of output segmentation mask to be returned as a result - * of the image segmentation operation. 
- * @seealso TFLOutputType + * Specifies the type of output segmentation mask to be returned as a result +@@ -63,24 +64,26 @@ NS_SWIFT_NAME(ImageSegmenterOptions) + /** + * Display names local for display names */ - @property(nonatomic, assign) TFLOutputType outputType; - - /** Display names local for display names*/ -@property(nonatomic, copy) NSString *displayNamesLocale; +@property(nonatomic, copy) NSString* displayNamesLocale; /** - * Initializes TFLImageSegmenterOptions with the model path set to the specified - * path to a model file. -- * @description The external model file, must be a single standalone TFLite -- * file. It could be packed with TFLite Model Metadata[1] and associated files -- * if exist. Fail to provide the necessary metadata and associated files might -- * result in errors. Check the [documentation](https://www.tensorflow.org/lite/convert/metadata) +- * Initializes a new `TFLImageSegmenterOptions` with the absolute path to the model file +- * stored locally on the device, set to the given the model path. ++ * Initializes a new `TFLImageSegmenterOptions` with the absolute path to the ++ * model file stored locally on the device, set to the given the model path. + * . + * @discussion The external model file, must be a single standalone TFLite + * file. It could be packed with TFLite Model Metadata[1] and associated files + * if exist. Fail to provide the necessary metadata and associated files might +- * result in errors. Check the [documentation](https://www.tensorflow.org/lite/convert/metadata) - * for each task about the specific requirement. -- * -+ * @description The external model file, must be a single standalone TFLite -+ * file. It could be packed with TFLite Model Metadata[1] and associated files -+ * if exist. Fail to provide the necessary metadata and associated files might + * result in errors. Check the + * [documentation](https://www.tensorflow.org/lite/convert/metadata) for each + * task about the specific requirement. 
-+ * - * @param modelPath Path to a TFLite model file. -- * -+ * - * @return An instance of TFLImageSegmenterOptions set to the specified - * modelPath. + * +- * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. ++ * @param modelPath An absolute path to a TensorFlow Lite model file stored ++ * locally on the device. + * + * @return An instance of `TFLImageSegmenterOptions` initialized to the given + * model path. */ --- (nullable instancetype)initWithModelPath:(nonnull NSString *)modelPath; -+- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath; +-- (instancetype)initWithModelPath:(NSString *)modelPath; ++- (instancetype)initWithModelPath:(NSString*)modelPath; @end -@@ -93,21 +94,23 @@ typedef NS_ENUM(NSUInteger, TFLOutputType) { +@@ -88,17 +91,19 @@ NS_SWIFT_NAME(ImageSegmenter) + @interface TFLImageSegmenter : NSObject + + /** +- * Creates a new instance of `TFLImageSegmenter` from the given `TFLImageSegmenterOptions`. ++ * Creates a new instance of `TFLImageSegmenter` from the given ++ * `TFLImageSegmenterOptions`. * - * @return A TFLImageSegmenter instance. + * @param options The options to use for configuring the `TFLImageSegmenter`. +- * @param error An optional error parameter populated when there is an error in initializing +- * the image segmenter. ++ * @param error An optional error parameter populated when there is an error in ++ * initializing the image segmenter. + * +- * @return A new instance of `TFLImageSegmenter` with the given options. `nil` if there is an error +- * in initializing the image segmenter. ++ * @return A new instance of `TFLImageSegmenter` with the given options. `nil` ++ * if there is an error in initializing the image segmenter. 
*/ -+ (nullable instancetype)imageSegmenterWithOptions:(nonnull TFLImageSegmenterOptions *)options - error:(NSError **)error ++ (nullable instancetype)imageSegmenterWithOptions: + (nonnull TFLImageSegmenterOptions*)options + error:(NSError**)error - NS_SWIFT_NAME(imageSegmenter(options:)); + NS_SWIFT_NAME(segmenter(options:)); + + (instancetype)new NS_UNAVAILABLE; +@@ -106,22 +111,23 @@ NS_SWIFT_NAME(ImageSegmenter) /** -- * Performs image segmentation on a GMLImage input, returns the segmentation -+ * Performs image segmentation on a GMLImage input, returns the segmentation - * results. + * Performs segmentation on the given GMLImage. * - * @param image input to the model. -- * -+ * - * @return Segmentation Result of type TFLSegmentationResult holds the - * segmentation masks returned by the image segmentation task. +- * @discussion This method currently supports segmentation of only the following types of images: ++ * @discussion This method currently supports segmentation of only the following ++ * types of images: + * 1. RGB and RGBA images for `GMLImageSourceTypeImage`. + * 2. kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and +- * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup +- * camera and get the frames for inference, you must request for this format +- * from AVCaptureVideoDataOutput. Otherwise your segmentation +- * results will be wrong. ++ * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to ++ * setup camera and get the frames for inference, you must request for this ++ * format from AVCaptureVideoDataOutput. Otherwise your segmentation results ++ * will be wrong. + * + * @param image An image to be segmented, represented as a `GMLImage`. + * +- * @return A TFLSegmentationResult that holds the segmentation masks returned by the image +- * segmentation task. `nil` if there is an error encountered during segmentation. Please see +- * `TFLSegmentationResult` for more details. 
++ * @return A TFLSegmentationResult that holds the segmentation masks returned by ++ * the image segmentation task. `nil` if there is an error encountered during ++ * segmentation. Please see `TFLSegmentationResult` for more details. */ -- (nullable TFLSegmentationResult *)segmentWithGMLImage:(GMLImage *)image -- error:(NSError *_Nullable *)error +- error:(NSError **)error +- (nullable TFLSegmentationResult*)segmentWithGMLImage:(GMLImage*)image -+ error: -+ (NSError* _Nullable*)error - NS_SWIFT_NAME(segment(gmlImage:)); ++ error:(NSError**)error + NS_SWIFT_NAME(segment(mlImage:)); - (instancetype)init NS_UNAVAILABLE; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m -index 2655673f68f58..8f22beebdb414 100644 +index 70068bfdd645a..7b7f3211df952 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m -@@ -33,7 +33,7 @@ +@@ -35,7 +35,7 @@ return self; } --- (nullable instancetype)initWithModelPath:(nonnull NSString *)modelPath { -+- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath { +-- (instancetype)initWithModelPath:(NSString *)modelPath { ++- (instancetype)initWithModelPath:(NSString*)modelPath { self = [self init]; if (self) { self.baseOptions.modelFile.filePath = modelPath; -@@ -45,14 +45,14 @@ +@@ -47,14 +47,14 @@ @implementation TFLImageSegmenter { /** ImageSegmenter backed by C API */ @@ -9918,7 +11274,7 @@ self = [super init]; if (self) { _imageSegmenter = imageSegmenter; -@@ -60,17 +60,19 @@ +@@ -62,8 +62,9 @@ return self; } @@ -9930,50 +11286,87 @@ TfLiteImageSegmenterOptions cOptions = TfLiteImageSegmenterOptionsCreate(); [options.baseOptions copyToCOptions:&(cOptions.base_options)]; +@@ -71,20 +72,22 @@ -- 
TfLiteSupportError *createImageSegmenterError = nil; -- TfLiteImageSegmenter *imageSegmenter = -+ TfLiteSupportError* createImageSegmenterError = nil; -+ TfLiteImageSegmenter* imageSegmenter = - TfLiteImageSegmenterFromOptions(&cOptions, &createImageSegmenterError); + if (options.displayNamesLocale) { + if (options.displayNamesLocale.UTF8String) { +- cOptions.display_names_locale = strdup(options.displayNamesLocale.UTF8String); ++ cOptions.display_names_locale = ++ strdup(options.displayNamesLocale.UTF8String); + if (!cOptions.display_names_locale) { + exit(-1); // Memory Allocation Failed. + } + } else { +- [TFLCommonUtils createCustomError:error +- withCode:TFLSupportErrorCodeInvalidArgumentError +- description:@"Could not convert (NSString *) to (char *)."]; ++ [TFLCommonUtils ++ createCustomError:error ++ withCode:TFLSupportErrorCodeInvalidArgumentError ++ description:@"Could not convert (NSString *) to (char *)."]; + return nil; + } + } -- if (!imageSegmenter || ![TFLCommonUtils checkCError:createImageSegmenterError toError:error]) { -+ if (!imageSegmenter || ![TFLCommonUtils checkCError:createImageSegmenterError -+ toError:error]) { - TfLiteSupportErrorDelete(createImageSegmenterError); +- TfLiteSupportError *cCreateImageSegmenterError = nil; +- TfLiteImageSegmenter *cImageSegmenter = ++ TfLiteSupportError* cCreateImageSegmenterError = nil; ++ TfLiteImageSegmenter* cImageSegmenter = + TfLiteImageSegmenterFromOptions(&cOptions, &cCreateImageSegmenterError); + + // Freeing memory of allocated string. +@@ -94,16 +97,17 @@ + TfLiteSupportErrorDelete(cCreateImageSegmenterError); + } + +- // Return nil if C object detector evaluates to nil. If an error was generted by the C layer, it +- // has already been populated to an NSError and deleted before returning from the method. ++ // Return nil if C object detector evaluates to nil. 
If an error was generted ++ // by the C layer, it has already been populated to an NSError and deleted ++ // before returning from the method. + if (!cImageSegmenter) { return nil; } -@@ -78,16 +80,17 @@ - return [[TFLImageSegmenter alloc] initWithImageSegmenter:imageSegmenter]; + return [[TFLImageSegmenter alloc] initWithImageSegmenter:cImageSegmenter]; } -- (nullable TFLSegmentationResult *)segmentWithGMLImage:(GMLImage *)image -- error:(NSError *_Nullable *)error { -- TfLiteFrameBuffer *cFrameBuffer = [image cFrameBufferWithError:error]; +- error:(NSError **)error { +- (nullable TFLSegmentationResult*)segmentWithGMLImage:(GMLImage*)image -+ error:(NSError* _Nullable*) -+ error { ++ error:(NSError**)error { + if (!image) { + [TFLCommonUtils createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError +@@ -111,15 +115,15 @@ + return nil; + } + +- TfLiteFrameBuffer *cFrameBuffer = [image cFrameBufferWithError:error]; + TfLiteFrameBuffer* cFrameBuffer = [image cFrameBufferWithError:error]; if (!cFrameBuffer) { return nil; } -- TfLiteSupportError *segmentError = nil; +- TfLiteSupportError *cSegmentError = nil; - TfLiteSegmentationResult *cSegmentationResult = -+ TfLiteSupportError* segmentError = nil; -+ TfLiteSegmentationResult* cSegmentationResult = - TfLiteImageSegmenterSegment(_imageSegmenter, cFrameBuffer, &segmentError); +- TfLiteImageSegmenterSegment(_imageSegmenter, cFrameBuffer, &cSegmentError); ++ TfLiteSupportError* cSegmentError = nil; ++ TfLiteSegmentationResult* cSegmentationResult = TfLiteImageSegmenterSegment( ++ _imageSegmenter, cFrameBuffer, &cSegmentError); free(cFrameBuffer->buffer); -@@ -96,12 +99,13 @@ - free(cFrameBuffer); - cFrameBuffer = nil; + cFrameBuffer->buffer = nil; +@@ -132,13 +136,14 @@ + TfLiteSupportErrorDelete(cSegmentError); + } -- if (!cSegmentationResult || ![TFLCommonUtils checkCError:segmentError toError:error]) { -+ if (!cSegmentationResult || ![TFLCommonUtils checkCError:segmentError -+ toError:error]) { - 
TfLiteSupportErrorDelete(segmentError); +- // Return nil if C result evaluates to nil. If an error was generted by the C layer, it has +- // already been populated to an NSError and deleted before returning from the method. ++ // Return nil if C result evaluates to nil. If an error was generted by the C ++ // layer, it has already been populated to an NSError and deleted before ++ // returning from the method. + if (!cSegmentationResult) { return nil; } @@ -9983,121 +11376,228 @@ TfLiteSegmentationResultDelete(cSegmentationResult); diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h -index 64271f31b6882..38b3b0ab18ebe 100644 +index 5e3a0e7186cfe..db76c90cc6868 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h -@@ -29,27 +29,29 @@ NS_ASSUME_NONNULL_BEGIN +@@ -30,28 +30,31 @@ NS_SWIFT_NAME(ObjectDetectorOptions) * Base options that is used for creation of any type of task. - * @seealso TFLBaseOptions + * @discussion Please see `TFLBaseOptions` for more details. */ -@property(nonatomic, copy) TFLBaseOptions *baseOptions; +@property(nonatomic, copy) TFLBaseOptions* baseOptions; /** * Options that configure the display and filtering of results. - * @seealso TFLClassificationOptions + * @discussion Please see `TFLClassificationOptions` for more details. */ -@property(nonatomic, copy) TFLClassificationOptions *classificationOptions; +@property(nonatomic, copy) TFLClassificationOptions* classificationOptions; /** -- * Initializes TFLObjectDetectorOptions with the model path set to the specified path to a model -- * file. -- * @description The external model file, must be a single standalone TFLite file. 
It could be packed +- * Initializes a new `TFLObjectDetectorOptions` with the absolute path to the model file +- * stored locally on the device, set to the given the model path. ++ * Initializes a new `TFLObjectDetectorOptions` with the absolute path to the ++ * model file stored locally on the device, set to the given the model path. + * +- * @discussion The external model file, must be a single standalone TFLite file. It could be packed - * with TFLite Model Metadata[1] and associated files if exist. Fail to provide the necessary - * metadata and associated files might result in errors. Check the [documentation] - * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. -+ * Initializes TFLObjectDetectorOptions with the model path set to the specified -+ * path to a model file. -+ * @description The external model file, must be a single standalone TFLite -+ * file. It could be packed with TFLite Model Metadata[1] and associated files -+ * if exist. Fail to provide the necessary metadata and associated files might ++ * @discussion The external model file, must be a single standalone TFLite file. ++ * It could be packed with TFLite Model Metadata[1] and associated files if ++ * exist. Fail to provide the necessary metadata and associated files might + * result in errors. Check the [documentation] + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the + * specific requirement. * - * @param modelPath Path to a TFLite model file. - * @return An instance of TFLObjectDetectorOptions set to the specified - * modelPath. +- * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. ++ * @param modelPath An absolute path to a TensorFlow Lite model file stored ++ * locally on the device. + * @return An instance of `TFLObjectDetectorOptions` initialized to the given + * model path. 
*/ --- (nullable instancetype)initWithModelPath:(nonnull NSString *)modelPath; -+- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath; +-- (instancetype)initWithModelPath:(NSString *)modelPath; ++- (instancetype)initWithModelPath:(NSString*)modelPath; @end -@@ -63,20 +65,22 @@ NS_ASSUME_NONNULL_BEGIN - * - * @return A TFLObjectDetector instance. - */ --+ (nullable instancetype)objectDetectorWithOptions:(nonnull TFLObjectDetectorOptions *)options -- error:(NSError **)error -++ (nullable instancetype)objectDetectorWithOptions: -+ (nonnull TFLObjectDetectorOptions*)options -+ error:(NSError**)error - NS_SWIFT_NAME(objectDetector(options:)); +@@ -59,40 +62,43 @@ NS_SWIFT_NAME(ObjectDetector) + @interface TFLObjectDetector : NSObject /** -- * Performs object detection on a GMLImage input, returns the detected objects in the image. -+ * Performs object detection on a GMLImage input, returns the detected objects -+ * in the image. +- * Creates a new instance of `TFLObjectDetector` from the given `TFLObjectDetectorOptions`. ++ * Creates a new instance of `TFLObjectDetector` from the given ++ * `TFLObjectDetectorOptions`. * - * @param image input to the model. - * @return Detection Result of type TFLDetectionResult an array of -- * detected objeects where each detected object has a bounding box and an array of TFLCategory -- * holding the predicted classes for the detected object. -+ * detected objeects where each detected object has a bounding box and an array -+ * of TFLCategory holding the predicted classes for the detected object. + * @param options The options to use for configuring the `TFLObjectDetector`. +- * @param error An optional error parameter populated when there is an error in initializing +- * the object detector. ++ * @param error An optional error parameter populated when there is an error in ++ * initializing the object detector. + * +- * @return A new instance of `TFLObjectDetector` with the given options. 
`nil` if there is an error +- * in initializing the object detector. ++ * @return A new instance of `TFLObjectDetector` with the given options. `nil` ++ * if there is an error in initializing the object detector. + */ +-+ (nullable instancetype)objectDetectorWithOptions:(TFLObjectDetectorOptions *)options +- error:(NSError **)error +++ (nullable instancetype)objectDetectorWithOptions: ++ (TFLObjectDetectorOptions*)options ++ error:(NSError**)error + NS_SWIFT_NAME(detector(options:)); + + + (instancetype)new NS_UNAVAILABLE; + + /** + * Performs object detection on the given GMLImage. +- * @discussion This method currently supports object detection on only the following types of +- * images: ++ * @discussion This method currently supports object detection on only the ++ * following types of images: + * 1. RGB and RGBA images for `GMLImageSourceTypeImage`. + * 2. `kCVPixelFormatType_32BGRA` for `GMLImageSourceTypePixelBuffer` and +- * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to setup +- * camera and get the frames for inference, you must request for this format +- * from AVCaptureVideoDataOutput. Otherwise your object detection +- * results will be wrong. ++ * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to ++ * setup camera and get the frames for inference, you must request for this ++ * format from AVCaptureVideoDataOutput. Otherwise your object detection results ++ * will be wrong. + * +- * @param image An image on which object detection is to be performed, represented as a `GMLImage`. ++ * @param image An image on which object detection is to be performed, ++ * represented as a `GMLImage`. + * +- * @return A `TFLDetectionResult` holding an array of TFLDetection objects, each having a bounding +- * box specifying the region the were detected in and an array of predicted classes. Please see +- * `TFLDetectionResult` for more details. 
++ * @return A `TFLDetectionResult` holding an array of TFLDetection objects, each ++ * having a bounding box specifying the region the were detected in and an array ++ * of predicted classes. Please see `TFLDetectionResult` for more details. */ -- (nullable TFLDetectionResult *)detectWithGMLImage:(GMLImage *)image -- error:(NSError *_Nullable *)error +- error:(NSError **)error +- (nullable TFLDetectionResult*)detectWithGMLImage:(GMLImage*)image -+ error:(NSError* _Nullable*)error - NS_SWIFT_NAME(detect(gmlImage:)); ++ error:(NSError**)error + NS_SWIFT_NAME(detect(mlImage:)); - (instancetype)init NS_UNAVAILABLE; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m -index 2dc3358955b44..1e47e254dc204 100644 +index 31cb241a2a448..def2e5b0b4877 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m -@@ -65,8 +65,9 @@ - + (nullable instancetype)objectDetectorWithOptions:(nonnull TFLObjectDetectorOptions *)options - error:(NSError **)error { - TfLiteObjectDetectorOptions cOptions = TfLiteObjectDetectorOptionsCreate(); -- if (! 
-- [options.classificationOptions copyToCOptions:&(cOptions.classification_options) error:error]) -+ if (![options.classificationOptions -+ copyToCOptions:&(cOptions.classification_options) -+ error:error]) - return nil; +@@ -40,7 +40,7 @@ + return self; + } - [options.baseOptions copyToCOptions:&(cOptions.base_options)]; -@@ -78,7 +79,8 @@ - [options.classificationOptions - deleteCStringArraysOfClassificationOptions:&(cOptions.classification_options)]; +-- (instancetype)initWithModelPath:(NSString *)modelPath { ++- (instancetype)initWithModelPath:(NSString*)modelPath { + self = [self init]; + if (self) { + self.baseOptions.modelFile.filePath = modelPath; +@@ -63,40 +63,45 @@ + return self; + } -- if (!objectDetector || ![TFLCommonUtils checkCError:createObjectDetectorError toError:error]) { -+ if (!objectDetector || ![TFLCommonUtils checkCError:createObjectDetectorError -+ toError:error]) { - TfLiteSupportErrorDelete(createObjectDetectorError); +-+ (nullable instancetype)objectDetectorWithOptions:(TFLObjectDetectorOptions *)options +- error:(NSError **)error { +++ (nullable instancetype)objectDetectorWithOptions: ++ (TFLObjectDetectorOptions*)options ++ error:(NSError**)error { + if (!options) { +- [TFLCommonUtils createCustomError:error +- withCode:TFLSupportErrorCodeInvalidArgumentError +- description:@"TFLObjectDetectorOptions argument cannot be nil."]; ++ [TFLCommonUtils ++ createCustomError:error ++ withCode:TFLSupportErrorCodeInvalidArgumentError ++ description:@"TFLObjectDetectorOptions argument cannot be nil."]; return nil; } -@@ -88,7 +90,7 @@ - - (nullable TFLDetectionResult *)detectWithGMLImage:(GMLImage *)image - error:(NSError *_Nullable *)error { + TfLiteObjectDetectorOptions cOptions = TfLiteObjectDetectorOptionsCreate(); +- if (![options.classificationOptions copyToCOptions:&(cOptions.classification_options) +- error:error]) { ++ if (![options.classificationOptions ++ copyToCOptions:&(cOptions.classification_options) ++ error:error]) { + // 
Deallocating any allocated memory on failure. +- [options.classificationOptions +- deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)]; ++ [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions: ++ &(cOptions.classification_options)]; + return nil; + } + + [options.baseOptions copyToCOptions:&(cOptions.base_options)]; + +- TfLiteSupportError *cCreateObjectDetectorError = nil; +- TfLiteObjectDetector *cObjectDetector = ++ TfLiteSupportError* cCreateObjectDetectorError = nil; ++ TfLiteObjectDetector* cObjectDetector = + TfLiteObjectDetectorFromOptions(&cOptions, &cCreateObjectDetectorError); + +- [options.classificationOptions +- deleteAllocatedMemoryOfClassificationOptions:&(cOptions.classification_options)]; ++ [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions: ++ &(cOptions.classification_options)]; + +- // Populate iOS error if TfliteSupportError is not null and afterwards delete it. ++ // Populate iOS error if TfliteSupportError is not null and afterwards delete ++ // it. + if (![TFLCommonUtils checkCError:cCreateObjectDetectorError toError:error]) { + TfLiteSupportErrorDelete(cCreateObjectDetectorError); + } + +- // Return nil if C object detector evaluates to nil. If an error was generted by the C layer, it +- // has already been populated to an NSError and deleted before returning from the method. ++ // Return nil if C object detector evaluates to nil. If an error was generted ++ // by the C layer, it has already been populated to an NSError and deleted ++ // before returning from the method. 
+ if (!cObjectDetector) { + return nil; + } +@@ -104,8 +109,8 @@ + return [[TFLObjectDetector alloc] initWithObjectDetector:cObjectDetector]; + } + +-- (nullable TFLDetectionResult *)detectWithGMLImage:(GMLImage *)image +- error:(NSError **)error { ++- (nullable TFLDetectionResult*)detectWithGMLImage:(GMLImage*)image ++ error:(NSError**)error { + if (!image) { + [TFLCommonUtils createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError +@@ -113,14 +118,14 @@ + return nil; + } + - TfLiteFrameBuffer *cFrameBuffer = [image cFrameBufferWithError:error]; + TfLiteFrameBuffer* cFrameBuffer = [image cFrameBufferWithError:error]; if (!cFrameBuffer) { return nil; -@@ -104,7 +106,8 @@ - free(cFrameBuffer); - cFrameBuffer = nil; + } -- if (!cDetectionResult || ![TFLCommonUtils checkCError:detectError toError:error]) { -+ if (!cDetectionResult || ![TFLCommonUtils checkCError:detectError -+ toError:error]) { - TfLiteSupportErrorDelete(detectError); +- TfLiteSupportError *cDetectError = nil; +- TfLiteDetectionResult *cDetectionResult = ++ TfLiteSupportError* cDetectError = nil; ++ TfLiteDetectionResult* cDetectionResult = + TfLiteObjectDetectorDetect(_objectDetector, cFrameBuffer, &cDetectError); + + free(cFrameBuffer->buffer); +@@ -134,8 +139,9 @@ + TfLiteSupportErrorDelete(cDetectError); + } + +- // Return nil if C result evaluates to nil. If an error was generted by the C layer, it has +- // already been populated to an NSError and deleted before returning from the method. ++ // Return nil if C result evaluates to nil. If an error was generted by the C ++ // layer, it has already been populated to an NSError and deleted before ++ // returning from the method. 
+ if (!cDetectionResult) { return nil; } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h @@ -10127,10 +11627,10 @@ @end diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m -index 1992fe306afc4..4339ade38a01a 100644 +index d1ab5105448fe..532f75ef25a6c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m -@@ -25,33 +25,37 @@ +@@ -25,35 +25,38 @@ @interface TFLCVPixelBufferUtils : NSObject @@ -10139,26 +11639,19 @@ - frameBufferFormat:(enum TfLiteFrameBufferFormat)frameBufferFormat - buffer:(uint8_t *)buffer - error:(NSError **)error; -- --+ (uint8_t *_Nullable)convertBGRAtoRGBforPixelBufferBaseAddress:(CVPixelBufferRef)pixelBuffer -- error:(NSError **)error; -- --+ (TfLiteFrameBuffer *)cFramebufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer -- error:(NSError **)error; ++ (TfLiteFrameBuffer*)cFrameBufferWithWidth:(int)width + height:(int)height + frameBufferFormat: + (enum TfLiteFrameBufferFormat)frameBufferFormat + buffer:(uint8_t*)buffer + error:(NSError**)error; -+ -++ (uint8_t* _Nullable) -+ convertBGRAtoRGBforPixelBufferBaseAddress:(CVPixelBufferRef)pixelBuffer -+ error:(NSError**)error; -+ + +-+ (TfLiteFrameBuffer *)cFramebufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer +- error:(NSError **)error; ++ (TfLiteFrameBuffer*)cFramebufferFromCVPixelBuffer: + (CVPixelBufferRef)pixelBuffer + error:(NSError**)error; + @end @interface UIImage (RawPixelDataUtils) @@ -10174,23 +11667,107 @@ - frameBufferFormat:(enum TfLiteFrameBufferFormat)frameBufferFormat - 
buffer:(uint8_t *)buffer - error:(NSError **)error { -- TfLiteFrameBuffer *cFrameBuffer = [TFLCommonUtils mallocWithSize:sizeof(TfLiteFrameBuffer) -- error:error]; ++ (TfLiteFrameBuffer*)cFrameBufferWithWidth:(int)width + height:(int)height + frameBufferFormat: + (enum TfLiteFrameBufferFormat)frameBufferFormat + buffer:(uint8_t*)buffer + error:(NSError**)error { + if (!buffer) { + return NULL; + } + +- TfLiteFrameBuffer *cFrameBuffer = [TFLCommonUtils mallocWithSize:sizeof(TfLiteFrameBuffer) +- error:error]; + TfLiteFrameBuffer* cFrameBuffer = + [TFLCommonUtils mallocWithSize:sizeof(TfLiteFrameBuffer) error:error]; if (cFrameBuffer) { cFrameBuffer->dimension.width = width; -@@ -63,9 +67,10 @@ +@@ -65,17 +68,18 @@ return cFrameBuffer; } +-+ (uint8_t *)createRGBImageDatafromImageData:(uint8_t *)data +- withWidth:(size_t)width +- height:(size_t)height +- stride:(size_t)stride +- pixelBufferFormat:(OSType)pixelBufferFormatType +- error:(NSError **)error { +++ (uint8_t*)createRGBImageDatafromImageData:(uint8_t*)data ++ withWidth:(size_t)width ++ height:(size_t)height ++ stride:(size_t)stride ++ pixelBufferFormat:(OSType)pixelBufferFormatType ++ error:(NSError**)error { + NSInteger destinationChannelCount = 3; + size_t destinationBytesPerRow = width * destinationChannelCount; + +- uint8_t *destPixelBufferAddress = +- [TFLCommonUtils mallocWithSize:sizeof(uint8_t) * height * destinationBytesPerRow error:error]; ++ uint8_t* destPixelBufferAddress = [TFLCommonUtils ++ mallocWithSize:sizeof(uint8_t) * height * destinationBytesPerRow ++ error:error]; + + if (!destPixelBufferAddress) { + return NULL; +@@ -95,19 +99,23 @@ + + switch (pixelBufferFormatType) { + case kCVPixelFormatType_32RGBA: { +- convertError = vImageConvert_RGBA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags); ++ convertError = vImageConvert_RGBA8888toRGB888(&srcBuffer, &destBuffer, ++ kvImageNoFlags); + break; + } + case kCVPixelFormatType_32BGRA: { +- convertError = 
vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags); ++ convertError = vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, ++ kvImageNoFlags); + break; + } + default: { +- [TFLCommonUtils createCustomError:error +- withCode:TFLSupportErrorCodeInvalidArgumentError +- description:@"Invalid source pixel buffer format. Expecting one of " +- @"kCVPixelFormatType_32RGBA, kCVPixelFormatType_32BGRA, " +- @"kCVPixelFormatType_32ARGB"]; ++ [TFLCommonUtils ++ createCustomError:error ++ withCode:TFLSupportErrorCodeInvalidArgumentError ++ description: ++ @"Invalid source pixel buffer format. Expecting one of " ++ @"kCVPixelFormatType_32RGBA, kCVPixelFormatType_32BGRA, " ++ @"kCVPixelFormatType_32ARGB"]; + + free(destPixelBufferAddress); + return NULL; +@@ -126,16 +134,17 @@ + return destPixelBufferAddress; + } + +-+ (uint8_t *)createRGBImageDatafromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer +- error:(NSError **)error { +++ (uint8_t*)createRGBImageDatafromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer ++ error:(NSError**)error { + CVPixelBufferLockBaseAddress(pixelBuffer, 0); + +- uint8_t *rgbData = [TFLCVPixelBufferUtils ++ uint8_t* rgbData = [TFLCVPixelBufferUtils + createRGBImageDatafromImageData:CVPixelBufferGetBaseAddress(pixelBuffer) + withWidth:CVPixelBufferGetWidth(pixelBuffer) + height:CVPixelBufferGetHeight(pixelBuffer) + stride:CVPixelBufferGetBytesPerRow(pixelBuffer) +- pixelBufferFormat:CVPixelBufferGetPixelFormatType(pixelBuffer) ++ pixelBufferFormat:CVPixelBufferGetPixelFormatType( ++ pixelBuffer) + error:error]; + + CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); +@@ -143,9 +152,10 @@ + return rgbData; + } + -+ (TfLiteFrameBuffer *)cFramebufferFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer - error:(NSError **)error { - uint8_t *buffer = NULL; @@ -10200,116 +11777,32 @@ + uint8_t* buffer = NULL; enum TfLiteFrameBufferFormat cPixelFormat = kRGB; - CVPixelBufferLockBaseAddress(pixelBuffer, 0); -@@ -74,26 +79,33 @@ - switch (pixelBufferFormat) 
{ - case kCVPixelFormatType_24RGB: { - cPixelFormat = kRGB; -- buffer = [TFLCVPixelBufferUtils copyPixelBufferDataForInference:pixelBuffer error:error]; -+ buffer = -+ [TFLCVPixelBufferUtils copyPixelBufferDataForInference:pixelBuffer -+ error:error]; - break; - } - case kCVPixelFormatType_32RGBA: { - cPixelFormat = kRGBA; -- buffer = [TFLCVPixelBufferUtils copyPixelBufferDataForInference:pixelBuffer error:error]; -+ buffer = -+ [TFLCVPixelBufferUtils copyPixelBufferDataForInference:pixelBuffer -+ error:error]; - break; - } + OSType pixelBufferFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); +@@ -154,14 +164,18 @@ case kCVPixelFormatType_32BGRA: { cPixelFormat = kRGB; -- buffer = [TFLCVPixelBufferUtils convertBGRAtoRGBforPixelBufferBaseAddress:pixelBuffer -- error:error]; -+ buffer = [TFLCVPixelBufferUtils -+ convertBGRAtoRGBforPixelBufferBaseAddress:pixelBuffer -+ error:error]; + +- buffer = [TFLCVPixelBufferUtils createRGBImageDatafromCVPixelBuffer:pixelBuffer error:error]; ++ buffer = ++ [TFLCVPixelBufferUtils createRGBImageDatafromCVPixelBuffer:pixelBuffer ++ error:error]; break; } default: { - [TFLCommonUtils createCustomError:error - withCode:TFLSupportErrorCodeInvalidArgumentError - description:@"Unsupported pixel format for CVPixelBuffer. Supported " -- @"pixel format types are kCVPixelFormatType_32RGBA, " -- @"kCVPixelFormatType_32BGRA, kCVPixelFormatType_24RGB"]; +- @"pixel format types are kCVPixelFormatType_32BGRA"]; + [TFLCommonUtils + createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError + description: + @"Unsupported pixel format for CVPixelBuffer. 
Supported " -+ @"pixel format types are kCVPixelFormatType_32RGBA, " -+ @"kCVPixelFormatType_32BGRA, kCVPixelFormatType_24RGB"]; ++ @"pixel format types are kCVPixelFormatType_32BGRA"]; } } -@@ -110,18 +122,22 @@ - error:error]; - } - --+ (UInt8 *)copyPixelBufferDataForInference:(CVPixelBufferRef)pixelBuffer error:(NSError **)error { -++ (UInt8*)copyPixelBufferDataForInference:(CVPixelBufferRef)pixelBuffer -+ error:(NSError**)error { - size_t height = CVPixelBufferGetHeight(pixelBuffer); - size_t stride = CVPixelBufferGetBytesPerRow(pixelBuffer); -- UInt8 *buffer = [TFLCommonUtils mallocWithSize:sizeof(UInt8) * height * stride error:error]; -+ UInt8* buffer = [TFLCommonUtils mallocWithSize:sizeof(UInt8) * height * stride -+ error:error]; - -- if (buffer) memcpy(buffer, CVPixelBufferGetBaseAddress(pixelBuffer), height * stride); -+ if (buffer) -+ memcpy(buffer, CVPixelBufferGetBaseAddress(pixelBuffer), height * stride); - - return buffer; - } - --+ (uint8_t *)convertBGRAtoRGBforPixelBufferBaseAddress:(CVPixelBufferRef)pixelBuffer -- error:(NSError **)error { -++ (uint8_t*)convertBGRAtoRGBforPixelBufferBaseAddress: -+ (CVPixelBufferRef)pixelBuffer -+ error:(NSError**)error { - size_t width = CVPixelBufferGetWidth(pixelBuffer); - size_t height = CVPixelBufferGetHeight(pixelBuffer); - size_t stride = CVPixelBufferGetBytesPerRow(pixelBuffer); -@@ -129,17 +145,21 @@ - int destinationChannelCount = 3; - size_t destinationBytesPerRow = destinationChannelCount * width; - -- uint8_t *pixelBufferBaseAddress = (uint8_t *)CVPixelBufferGetBaseAddress(pixelBuffer); -+ uint8_t* pixelBufferBaseAddress = -+ (uint8_t*)CVPixelBufferGetBaseAddress(pixelBuffer); - -- uint8_t *destPixelBufferAddress = -- [TFLCommonUtils mallocWithSize:sizeof(uint8_t) * height * destinationBytesPerRow error:error]; -+ uint8_t* destPixelBufferAddress = [TFLCommonUtils -+ mallocWithSize:sizeof(uint8_t) * height * destinationBytesPerRow -+ error:error]; - - if (!destPixelBufferAddress) { - return NULL; - } - 
-- vImage_Buffer srcBuffer = { -- .data = pixelBufferBaseAddress, .height = height, .width = width, .rowBytes = stride}; -+ vImage_Buffer srcBuffer = {.data = pixelBufferBaseAddress, -+ .height = height, -+ .width = width, -+ .rowBytes = stride}; - - vImage_Buffer destBuffer = {.data = destPixelBufferAddress, - .height = height, -@@ -147,7 +167,8 @@ - .rowBytes = destinationBytesPerRow}; - - vImage_Error convertError = kvImageNoError; -- convertError = vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags); -+ convertError = -+ vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags); - - if (convertError != kvImageNoError) { - [TFLCommonUtils createCustomError:error -@@ -163,8 +184,8 @@ +@@ -176,8 +190,8 @@ @implementation UIImage (RawPixelDataUtils) @@ -10320,7 +11813,7 @@ if (self.CGImage) { frameBuffer = [self frameBufferFromCGImage:self.CGImage error:error]; -@@ -189,47 +210,50 @@ +@@ -202,59 +216,65 @@ } CGDataProviderRef imageDataProvider = CGImageGetDataProvider(cgImage); @@ -10353,8 +11846,8 @@ -+ (UInt8 *_Nullable)pixelDataFromCGImage:(CGImageRef)cgImage error:(NSError **)error { ++ (UInt8* _Nullable)pixelDataFromCGImage:(CGImageRef)cgImage + error:(NSError**)error { - long width = CGImageGetWidth(cgImage); - long height = CGImageGetHeight(cgImage); + size_t width = CGImageGetWidth(cgImage); + size_t height = CGImageGetHeight(cgImage); NSInteger bitsPerComponent = 8; NSInteger channelCount = 4; @@ -10362,6 +11855,7 @@ + UInt8* buffer_to_return = NULL; CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB(); + size_t bytesPerRow = channelCount * width; // iOS infers bytesPerRow if it is set to 0. - // See https://developer.apple.com/documentation/coregraphics/1455939-cgbitmapcontextcreate @@ -10369,52 +11863,47 @@ + // https://developer.apple.com/documentation/coregraphics/1455939-cgbitmapcontextcreate // But for segmentation test image, this was not the case. // Hence setting it to the value of channelCount*width. 
-- CGContextRef context = -- CGBitmapContextCreate(nil, width, height, bitsPerComponent, channelCount * width, colorSpace, -- kCGImageAlphaNoneSkipLast | kCGBitmapByteOrder32Big); -+ CGContextRef context = CGBitmapContextCreate( -+ nil, width, height, bitsPerComponent, channelCount * width, colorSpace, -+ kCGImageAlphaNoneSkipLast | kCGBitmapByteOrder32Big); + // kCGImageAlphaNoneSkipLast specifies that Alpha will always be next to B. + // kCGBitmapByteOrder32Big specifies that R will be stored before B. + // In combination they signify a pixelFormat of kCVPixelFormatType32RGBA. +- CGBitmapInfo bitMapinfoFor32RGBA = kCGImageAlphaNoneSkipLast | kCGBitmapByteOrder32Big; +- CGContextRef context = CGBitmapContextCreate(nil, width, height, bitsPerComponent, bytesPerRow, +- colorSpace, bitMapinfoFor32RGBA); ++ CGBitmapInfo bitMapinfoFor32RGBA = ++ kCGImageAlphaNoneSkipLast | kCGBitmapByteOrder32Big; ++ CGContextRef context = ++ CGBitmapContextCreate(nil, width, height, bitsPerComponent, bytesPerRow, ++ colorSpace, bitMapinfoFor32RGBA); if (context) { CGContextDrawImage(context, CGRectMake(0, 0, width, height), cgImage); -- buffer_to_return = -- [UIImage populateRGBBufferFromSourceRGBABuffer:CGBitmapContextGetData(context) -- width:width -- height:height]; -+ buffer_to_return = [UIImage -+ populateRGBBufferFromSourceRGBABuffer:CGBitmapContextGetData(context) -+ width:width -+ height:height]; +- uint8_t *srcData = CGBitmapContextGetData(context); ++ uint8_t* srcData = CGBitmapContextGetData(context); + + if (srcData) { +- // We have drawn the image as an RGBA image with 8 bitsPerComponent and hence can safely input +- // a pixel format of type kCVPixelFormatType_32RGBA for conversion by vImage. 
+- buffer_to_return = +- [TFLCVPixelBufferUtils createRGBImageDatafromImageData:srcData +- withWidth:width +- height:height +- stride:bytesPerRow +- pixelBufferFormat:kCVPixelFormatType_32RGBA +- error:error]; ++ // We have drawn the image as an RGBA image with 8 bitsPerComponent and ++ // hence can safely input a pixel format of type kCVPixelFormatType_32RGBA ++ // for conversion by vImage. ++ buffer_to_return = [TFLCVPixelBufferUtils ++ createRGBImageDatafromImageData:srcData ++ withWidth:width ++ height:height ++ stride:bytesPerRow ++ pixelBufferFormat:kCVPixelFormatType_32RGBA ++ error:error]; + } + CGContextRelease(context); - } - -@@ -244,16 +268,18 @@ - return buffer_to_return; - } - --+ (nullable UInt8 *)populateRGBBufferFromSourceRGBABuffer:(UInt8 *)buffer -- width:(size_t)width -- height:(size_t)height { -- if (!buffer) return NULL; -++ (nullable UInt8*)populateRGBBufferFromSourceRGBABuffer:(UInt8*)buffer -+ width:(size_t)width -+ height:(size_t)height { -+ if (!buffer) -+ return NULL; - - int sourceChannelCount = 4; - int destChannelCount = 3; - -- UInt8 *buffer_to_return = -- [TFLCommonUtils mallocWithSize:sizeof(UInt8) * height * destChannelCount * width error:nil]; -+ UInt8* buffer_to_return = [TFLCommonUtils -+ mallocWithSize:sizeof(UInt8) * height * destChannelCount * width -+ error:nil]; - if (!buffer_to_return) { - return NULL; - } -@@ -269,28 +295,32 @@ +@@ -265,18 +285,21 @@ return buffer_to_return; } @@ -10424,10 +11913,6 @@ + error:(NSError**)error { + UInt8* buffer = [UIImage pixelDataFromCGImage:cgImage error:error]; - if (buffer == NULL) { - return NULL; - } - - return [TFLCVPixelBufferUtils cFrameBufferWithWidth:(int)CGImageGetWidth(cgImage) - height:(int)CGImageGetHeight(cgImage) - frameBufferFormat:kRGB @@ -10449,16 +11934,18 @@ int width = 0; int height = 0; - if (ciImage.pixelBuffer) { -- buffer = [TFLCVPixelBufferUtils convertBGRAtoRGBforPixelBufferBaseAddress:ciImage.pixelBuffer -- error:error]; -+ buffer = [TFLCVPixelBufferUtils -+ 
convertBGRAtoRGBforPixelBufferBaseAddress:ciImage.pixelBuffer -+ error:error]; +@@ -285,17 +308,20 @@ width = (int)CVPixelBufferGetWidth(ciImage.pixelBuffer); height = (int)CVPixelBufferGetHeight(ciImage.pixelBuffer); -@@ -299,9 +329,11 @@ +- buffer = [TFLCVPixelBufferUtils createRGBImageDatafromCVPixelBuffer:ciImage.pixelBuffer +- error:error]; ++ buffer = [TFLCVPixelBufferUtils ++ createRGBImageDatafromCVPixelBuffer:ciImage.pixelBuffer ++ error:error]; + + } else if (ciImage.CGImage) { + buffer = [UIImage pixelDataFromCGImage:ciImage.CGImage error:error]; width = (int)CGImageGetWidth(ciImage.CGImage); height = (int)CGImageGetWidth(ciImage.CGImage); } else { @@ -10472,8 +11959,8 @@ + @"CIImage should have CGImage or CVPixelBuffer info."]; } - if (buffer == NULL) { -@@ -319,19 +351,23 @@ + return [TFLCVPixelBufferUtils cFrameBufferWithWidth:width +@@ -309,19 +335,23 @@ @implementation GMLImage (Utils) @@ -10504,7 +11991,7 @@ break; } case GMLImageSourceTypeImage: { -@@ -362,14 +398,17 @@ +@@ -352,14 +382,17 @@ return nil; } @@ -10529,6 +12016,319 @@ return [[GMLImage alloc] initWithImage:image]; } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m +index 3e2df5d4bf023..cd389b9c0a9a8 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m ++++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m +@@ -17,10 +17,12 @@ + #import "tensorflow_lite_support/ios/sources/TFLCommon.h" + #import "tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h" + +-#define VerifyError(error, expectedErrorDomain, expectedErrorCode, expectedLocalizedDescription) \ +- XCTAssertEqual(error.domain, expectedErrorDomain); \ +- XCTAssertEqual(error.code, expectedErrorCode); \ +- 
XCTAssertEqualObjects(error.localizedDescription, expectedLocalizedDescription); ++#define VerifyError(error, expectedErrorDomain, expectedErrorCode, \ ++ expectedLocalizedDescription) \ ++ XCTAssertEqual(error.domain, expectedErrorDomain); \ ++ XCTAssertEqual(error.code, expectedErrorCode); \ ++ XCTAssertEqualObjects(error.localizedDescription, \ ++ expectedLocalizedDescription); + + NS_ASSUME_NONNULL_BEGIN + +@@ -33,15 +35,20 @@ NS_ASSUME_NONNULL_BEGIN + NSInteger inDataLength = 5; + float inData[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + +- TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataLength]; ++ TFLFloatBuffer* inBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataLength]; + + NSInteger bufferSize = 5; +- TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; ++ TFLRingBuffer* ringBuffer = ++ [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + +- XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:0 size:inDataLength error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:inBuffer ++ offset:0 ++ size:inDataLength ++ error:nil]); + // State after load: [1.0, 2.0, 3.0, 4.0, 5.0] + +- TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer; ++ TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + +@@ -55,16 +62,21 @@ NS_ASSUME_NONNULL_BEGIN + - (void)testLoadSucceedsWithPartialLengthBuffer { + NSInteger inDataSize = 3; + float inData[] = {1.0f, 2.0f, 3.0f}; +- TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataSize]; ++ TFLFloatBuffer* inBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataSize]; + + NSInteger bufferSize = 5; +- TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; ++ TFLRingBuffer* ringBuffer = ++ [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + +- XCTAssertTrue([ringBuffer loadBuffer:inBuffer 
offset:0 size:inDataSize error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:inBuffer ++ offset:0 ++ size:inDataSize ++ error:nil]); + + // State after load: [0.0, 0.0, 1.0, 2.0, 3.0] + +- TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer; ++ TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + +@@ -80,23 +92,32 @@ NS_ASSUME_NONNULL_BEGIN + NSInteger initialDataSize = 4; + float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f}; + +- TFLFloatBuffer *initialBuffer = +- [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize]; ++ TFLFloatBuffer* initialBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) ++ size:initialDataSize]; + + NSInteger bufferSize = 5; +- TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; ++ TFLRingBuffer* ringBuffer = ++ [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + +- XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:initialBuffer ++ offset:0 ++ size:initialDataSize ++ error:nil]); + + // State after load: [0.0, 1.0, 2.0, 3.0, 4.0] + + NSInteger inDataSize = 3; + float inArray[] = {5, 6, 7}; +- TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:inDataSize]; ++ TFLFloatBuffer* inBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:inDataSize]; + +- XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:0 size:inDataSize error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:inBuffer ++ offset:0 ++ size:inDataSize ++ error:nil]); + +- TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer; ++ TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + +@@ -112,24 +133,33 @@ NS_ASSUME_NONNULL_BEGIN + NSInteger initialDataSize = 5; + float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + +- TFLFloatBuffer 
*initialBuffer = +- [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize]; ++ TFLFloatBuffer* initialBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) ++ size:initialDataSize]; + + NSInteger bufferSize = 5; +- TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; ++ TFLRingBuffer* ringBuffer = ++ [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + +- XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:initialBuffer ++ offset:0 ++ size:initialDataSize ++ error:nil]); + + // State after load: [1.0, 2.0, 3.0, 4.0, 5.0] + + NSInteger sourceDataSize = 6; + float sourceArray[] = {6, 7, 8, 9, 10, 11}; +- TFLFloatBuffer *sourceBuffer = +- [[TFLFloatBuffer alloc] initWithData:&(sourceArray[0]) size:sourceDataSize]; ++ TFLFloatBuffer* sourceBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(sourceArray[0]) ++ size:sourceDataSize]; + +- XCTAssertTrue([ringBuffer loadBuffer:sourceBuffer offset:0 size:sourceDataSize error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:sourceBuffer ++ offset:0 ++ size:sourceDataSize ++ error:nil]); + +- TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer; ++ TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + +@@ -145,25 +175,34 @@ NS_ASSUME_NONNULL_BEGIN + NSInteger initialDataSize = 5; + float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + +- TFLFloatBuffer *initialBuffer = +- [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize]; ++ TFLFloatBuffer* initialBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) ++ size:initialDataSize]; + + NSInteger bufferSize = 5; +- TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; ++ TFLRingBuffer* ringBuffer = ++ [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + +- XCTAssertTrue([ringBuffer 
loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:initialBuffer ++ offset:0 ++ size:initialDataSize ++ error:nil]); + + // State after load: [1.0, 2.0, 3.0, 4.0, 5.0] + + NSInteger totalInSize = 8; + float inArray[] = {6, 7, 8, 9, 10, 11, 12, 13}; +- TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize]; ++ TFLFloatBuffer* inBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize]; + + NSInteger offset = 2; + NSInteger inDataSize = 6; +- XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:offset size:inDataSize error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:inBuffer ++ offset:offset ++ size:inDataSize ++ error:nil]); + +- TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer; ++ TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + +@@ -179,25 +218,34 @@ NS_ASSUME_NONNULL_BEGIN + NSInteger initialDataSize = 2; + float initialArray[] = {1.0f, 2.0f}; + +- TFLFloatBuffer *initialBuffer = +- [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize]; ++ TFLFloatBuffer* initialBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) ++ size:initialDataSize]; + + NSInteger bufferSize = 5; +- TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; ++ TFLRingBuffer* ringBuffer = ++ [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + +- XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:initialBuffer ++ offset:0 ++ size:initialDataSize ++ error:nil]); + + // State after load: [0.0, 0.0, 0.0, 1.0, 2.0] + + NSInteger totalInSize = 4; + float inArray[] = {6.0f, 7.0f, 8.0f, 9.0f}; +- TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize]; ++ TFLFloatBuffer* inBuffer = ++ [[TFLFloatBuffer alloc] 
initWithData:&(inArray[0]) size:totalInSize]; + + NSInteger offset = 2; + NSInteger inDataSize = 2; +- XCTAssertTrue([ringBuffer loadBuffer:inBuffer offset:offset size:inDataSize error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:inBuffer ++ offset:offset ++ size:inDataSize ++ error:nil]); + +- TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer; ++ TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + +@@ -213,26 +261,36 @@ NS_ASSUME_NONNULL_BEGIN + NSInteger initialDataSize = 2; + float initialArray[] = {1.0f, 2.0f}; + +- TFLFloatBuffer *initialBuffer = +- [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize]; ++ TFLFloatBuffer* initialBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) ++ size:initialDataSize]; + + NSInteger bufferSize = 5; +- TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; ++ TFLRingBuffer* ringBuffer = ++ [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + +- XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:initialBuffer ++ offset:0 ++ size:initialDataSize ++ error:nil]); + + NSInteger totalInSize = 4; + float inArray[] = {6.0f, 7.0f, 8.0f, 9.0f}; +- TFLFloatBuffer *inBuffer = [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize]; ++ TFLFloatBuffer* inBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize]; + + NSInteger offset = 2; + NSInteger inDataSize = 3; + +- NSError *error = nil; +- XCTAssertFalse([ringBuffer loadBuffer:inBuffer offset:offset size:inDataSize error:&error]); ++ NSError* error = nil; ++ XCTAssertFalse([ringBuffer loadBuffer:inBuffer ++ offset:offset ++ size:inDataSize ++ error:&error]); + + XCTAssertNotNil(error); +- VerifyError(error, @"org.tensorflow.lite.tasks", TFLSupportErrorCodeInvalidArgumentError, ++ VerifyError(error, 
@"org.tensorflow.lite.tasks", ++ TFLSupportErrorCodeInvalidArgumentError, + @"offset + size exceeds the maximum size of the source buffer."); + } + +@@ -240,19 +298,24 @@ NS_ASSUME_NONNULL_BEGIN + NSInteger initialDataSize = 2; + float initialArray[] = {1.0f, 2.0f}; + +- TFLFloatBuffer *initialBuffer = +- [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) size:initialDataSize]; ++ TFLFloatBuffer* initialBuffer = ++ [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) ++ size:initialDataSize]; + + NSInteger bufferSize = 5; +- TFLRingBuffer *ringBuffer = [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; ++ TFLRingBuffer* ringBuffer = ++ [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + +- XCTAssertTrue([ringBuffer loadBuffer:initialBuffer offset:0 size:initialDataSize error:nil]); ++ XCTAssertTrue([ringBuffer loadBuffer:initialBuffer ++ offset:0 ++ size:initialDataSize ++ error:nil]); + + [ringBuffer clear]; + + float expectedData[] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + +- TFLFloatBuffer *outBuffer = ringBuffer.floatBuffer; ++ TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m index d03b6044bdd68..b1ed8cf1e2f6d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.m @@ -10602,7 +12402,7 @@ TFLClassificationResult *classificationResults = [imageClassifier classifyWithGMLImage:gmlImage diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m -index 
a7ba81f4c1b2d..39c3153f82539 100644 +index c2977475f6d4f..f483a516b9bc6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m @@ -18,10 +18,11 @@ @@ -10638,11 +12438,11 @@ + // Put setup code here. This method is called before the invocation of each + // test method in the class. [super setUp]; -- self.modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"deeplabv3" -- ofType:@"tflite"]; +- self.modelPath = [[NSBundle bundleForClass:self.class] pathForResource:@"deeplabv3" +- ofType:@"tflite"]; + self.modelPath = -+ [[NSBundle bundleForClass:[self class]] pathForResource:@"deeplabv3" -+ ofType:@"tflite"]; ++ [[NSBundle bundleForClass:self.class] pathForResource:@"deeplabv3" ++ ofType:@"tflite"]; XCTAssertNotNil(self.modelPath); } @@ -10667,11 +12467,11 @@ + error:nil]; XCTAssertNotNil(imageSegmenter); -- GMLImage *gmlImage = [GMLImage imageFromBundleWithClass:[self class] +- GMLImage *gmlImage = [GMLImage imageFromBundleWithClass:self.class - fileName:@"segmentation_input_rotation0" - ofType:@"jpg"]; + GMLImage* gmlImage = -+ [GMLImage imageFromBundleWithClass:[self class] ++ [GMLImage imageFromBundleWithClass:self.class + fileName:@"segmentation_input_rotation0" + ofType:@"jpg"]; XCTAssertNotNil(gmlImage); @@ -10693,11 +12493,11 @@ XCTAssertNotNil(segmentationResult.segmentations[0].categoryMask); XCTAssertTrue(segmentationResult.segmentations[0].categoryMask.mask != nil); -- GMLImage *goldenImage = [GMLImage imageFromBundleWithClass:[self class] +- GMLImage *goldenImage = [GMLImage imageFromBundleWithClass:self.class - fileName:@"segmentation_golden_rotation0" - ofType:@"png"]; + GMLImage* goldenImage = -+ [GMLImage imageFromBundleWithClass:[self class] ++ [GMLImage imageFromBundleWithClass:self.class + fileName:@"segmentation_golden_rotation0" + 
ofType:@"png"]; @@ -10737,17 +12537,66 @@ @end diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m -index e84b51063853b..7dffe7e15015d 100644 +index f6820f335e18b..f7091a5995b02 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m -@@ -18,7 +18,6 @@ +@@ -18,16 +18,22 @@ #import "tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h" #import "tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h" -- - #define VerifyDetection(detection, expectedBoundingBox, expectedFirstScore, expectedFirstLabel) \ - XCTAssertGreaterThan([detection.categories count], 0); \ - NSLog(@"Detected %f", detection.categories[0].score); \ +-#define VerifyDetection(detection, expectedBoundingBox, expectedFirstScore, expectedFirstLabel) \ +- XCTAssertGreaterThan(detection.categories.count, 0); \ +- NSLog(@"Detected %f", detection.categories[0].score); \ +- NSLog(@"Expected %f", expectedFirstScore); \ +- XCTAssertEqual(detection.boundingBox.origin.x, expectedBoundingBox.origin.x); \ +- XCTAssertEqual(detection.boundingBox.origin.y, expectedBoundingBox.origin.y); \ +- XCTAssertEqual(detection.boundingBox.size.width, expectedBoundingBox.size.width); \ +- XCTAssertEqual(detection.boundingBox.size.height, expectedBoundingBox.size.height); \ +- XCTAssertEqualObjects(detection.categories[0].label, expectedFirstLabel); \ +- XCTAssertEqualWithAccuracy(detection.categories[0].score, expectedFirstScore, 0.001) ++#define VerifyDetection(detection, expectedBoundingBox, expectedFirstScore, \ ++ expectedFirstLabel) \ ++ XCTAssertGreaterThan(detection.categories.count, 0); \ ++ 
NSLog(@"Detected %f", detection.categories[0].score); \ ++ NSLog(@"Expected %f", expectedFirstScore); \ ++ XCTAssertEqual(detection.boundingBox.origin.x, \ ++ expectedBoundingBox.origin.x); \ ++ XCTAssertEqual(detection.boundingBox.origin.y, \ ++ expectedBoundingBox.origin.y); \ ++ XCTAssertEqual(detection.boundingBox.size.width, \ ++ expectedBoundingBox.size.width); \ ++ XCTAssertEqual(detection.boundingBox.size.height, \ ++ expectedBoundingBox.size.height); \ ++ XCTAssertEqualObjects(detection.categories[0].label, expectedFirstLabel); \ ++ XCTAssertEqualWithAccuracy(detection.categories[0].score, \ ++ expectedFirstScore, 0.001) + + @interface TFLObjectDetectorTests : XCTestCase + @property(nonatomic, nullable) NSString *modelPath; +@@ -77,8 +83,9 @@ + [TFLObjectDetector objectDetectorWithOptions:objectDetectorOptions error:nil]; + XCTAssertNotNil(objectDetector); + +- GMLImage *gmlImage = +- [GMLImage imageFromBundleWithClass:self.class fileName:@"cats_and_dogs" ofType:@"jpg"]; ++ GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class ++ fileName:@"cats_and_dogs" ++ ofType:@"jpg"]; + XCTAssertNotNil(gmlImage); + + TFLDetectionResult *detectionResults = [objectDetector detectWithGMLImage:gmlImage error:nil]; +@@ -95,8 +102,9 @@ + [TFLObjectDetector objectDetectorWithOptions:objectDetectorOptions error:nil]; + XCTAssertNotNil(objectDetector); + +- GMLImage *gmlImage = +- [GMLImage imageFromBundleWithClass:self.class fileName:@"cats_and_dogs" ofType:@"jpg"]; ++ GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class ++ fileName:@"cats_and_dogs" ++ ofType:@"jpg"]; + XCTAssertNotNil(gmlImage); + + TFLDetectionResult *detectionResult = [objectDetector detectWithGMLImage:gmlImage error:nil]; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h index ed679c22a467b..c10c82afc1913 100644 
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h @@ -19321,7 +21170,7 @@ - } } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java -index 4f15f3d6b7d64..b3eb11fb32f5f 100644 +index 043528aa88138..85c5d12e2fc53 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java @@ -22,13 +22,7 @@ import android.media.AudioFormat; @@ -19337,11 +21186,11 @@ -import java.util.List; + import org.tensorflow.lite.DataType; - import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.audio.TensorAudio; -@@ -40,6 +34,14 @@ import org.tensorflow.lite.task.core.TaskJniUtils; - import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat; +@@ -40,6 +34,14 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; + import org.tensorflow.lite.task.core.annotations.UsedByReflection; +import java.io.File; +import java.io.IOException; @@ -19354,7 +21203,7 @@ /** * Performs classification on audio waveforms. * -@@ -72,468 +74,437 @@ import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; +@@ -72,468 +74,437 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection; * CLI demo tool</a> for easily trying out this API. 
*/ public final class AudioClassifier extends BaseTaskApi { @@ -20229,7 +22078,7 @@ + private static native void deinitJni(long nativeHandle); } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java -index 446d328441a97..7d5b07fa735cd 100644 +index 9c0cdf9e249ae..8e8270269dad8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java @@ -16,11 +16,13 @@ limitations under the License. @@ -20237,18 +22086,18 @@ import com.google.auto.value.AutoValue; + -+import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.label.Category; ++import org.tensorflow.lite.task.core.annotations.UsedByReflection; + import java.util.ArrayList; import java.util.Collections; import java.util.List; --import org.tensorflow.lite.annotations.UsedByReflection; -import org.tensorflow.lite.support.label.Category; +-import org.tensorflow.lite.task.core.annotations.UsedByReflection; /** * The classification results of one head in a multihead (a.k.a. 
multi-output) {@link -@@ -31,18 +33,18 @@ import org.tensorflow.lite.support.label.Category; +@@ -31,18 +33,18 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection; @AutoValue @UsedByReflection("audio_classifier_jni.cc") public abstract class Classifications { @@ -20911,6 +22760,17 @@ - private static native long createProtoBaseOptions(int delegate, int numThreads); + private static native long createProtoBaseOptions(int delegate, int numThreads); } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java +index bfa1ea750cf1f..fb1dfec82d7b4 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java +@@ -27,5 +27,5 @@ import java.lang.annotation.Target; + */ + @Target({ElementType.METHOD, ElementType.FIELD, ElementType.TYPE, ElementType.CONSTRUCTOR}) + public @interface UsedByReflection { +- String value(); ++ String value(); + } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java index 287ba444c386b..b1784d02f2362 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java @@ -21060,8 +22920,206 @@ } - } } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java 
b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java +index f5cc5af615117..a39247f1239c8 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java +@@ -16,37 +16,38 @@ limitations under the License. + package org.tensorflow.lite.task.processor; + + import com.google.auto.value.AutoValue; ++ ++import org.tensorflow.lite.task.core.annotations.UsedByReflection; ++ + import java.nio.ByteBuffer; + import java.nio.ByteOrder; +-import org.tensorflow.lite.task.core.annotations.UsedByReflection; + + /** Represents the search result of a Searcher model. */ + @AutoValue + @UsedByReflection("searcher_jni.cc") + public abstract class NearestNeighbor { +- +- @UsedByReflection("searcher_jni.cc") +- static NearestNeighbor create(byte[] metadataArray, float distance) { +- // Convert byte[] metadataArray to ByteBuffer which handles endianess better. +- // +- // Ideally, the API should accept a ByteBuffer instead of a byte[]. However, converting byte[] +- // to ByteBuffer in JNI will lead to unnecessarily complex code which involves 6 more reflection +- // calls. We can make this method package private, because users in general shouldn't need to +- // create NearestNeighbor instances, but only consume the objects return from Task Library. This +- // API will be used mostly for internal purpose. +- ByteBuffer metadata = ByteBuffer.wrap(metadataArray); +- metadata.order(ByteOrder.nativeOrder()); +- return new AutoValue_NearestNeighbor(metadata, distance); +- } +- +- /** +- * Gets the user-defined metadata about the result. This could be a label, a unique ID, a +- * serialized proto of some sort, etc. +- * +- * <p><b>Do not mutate</b> the returned metadata. 
+- */ +- public abstract ByteBuffer getMetadata(); +- +- /** Gets the distance score indicating how confident the result is. Lower is better. */ +- public abstract float getDistance(); ++ @UsedByReflection("searcher_jni.cc") ++ static NearestNeighbor create(byte[] metadataArray, float distance) { ++ // Convert byte[] metadataArray to ByteBuffer which handles endianess better. ++ // ++ // Ideally, the API should accept a ByteBuffer instead of a byte[]. However, converting ++ // byte[] to ByteBuffer in JNI will lead to unnecessarily complex code which involves 6 more ++ // reflection calls. We can make this method package private, because users in general ++ // shouldn't need to create NearestNeighbor instances, but only consume the objects return ++ // from Task Library. This API will be used mostly for internal purpose. ++ ByteBuffer metadata = ByteBuffer.wrap(metadataArray); ++ metadata.order(ByteOrder.nativeOrder()); ++ return new AutoValue_NearestNeighbor(metadata, distance); ++ } ++ ++ /** ++ * Gets the user-defined metadata about the result. This could be a label, a unique ID, a ++ * serialized proto of some sort, etc. ++ * ++ * <p><b>Do not mutate</b> the returned metadata. ++ */ ++ public abstract ByteBuffer getMetadata(); ++ ++ /** Gets the distance score indicating how confident the result is. Lower is better. 
*/ ++ public abstract float getDistance(); + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java +index fa601edf92b30..86f5fdde0187c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java +@@ -16,66 +16,68 @@ limitations under the License. + package org.tensorflow.lite.task.processor; + + import androidx.annotation.Nullable; ++ + import com.google.auto.value.AutoValue; ++ + import java.io.File; + + /** Options to configure Searcher API. */ + @AutoValue + public abstract class SearcherOptions { +- private static final boolean DEFAULT_L2_NORMALIZE = false; +- private static final boolean DEFAULT_QUANTIZE = false; +- private static final int DEFAULT_MAX_RESULTS = 5; +- +- public abstract boolean getL2Normalize(); +- +- public abstract boolean getQuantize(); +- +- @Nullable +- public abstract File getIndexFile(); +- +- public abstract int getMaxResults(); +- +- public static Builder builder() { +- return new AutoValue_SearcherOptions.Builder() +- .setL2Normalize(DEFAULT_L2_NORMALIZE) +- .setQuantize(DEFAULT_QUANTIZE) +- .setIndexFile(null) +- .setMaxResults(DEFAULT_MAX_RESULTS); +- } +- +- /** Builder for {@link SearcherOptions}. */ +- @AutoValue.Builder +- public abstract static class Builder { +- /** +- * Sets whether to normalize the embedding feature vector with L2 norm. Defaults to false. +- * +- * <p>Use this option only if the model does not already contain a native L2_NORMALIZATION +- * TFLite Op. In most cases, this is already the case and L2 norm is thus achieved through +- * TFLite inference. 
+- */ +- public abstract Builder setL2Normalize(boolean l2Normalize); +- +- /** +- * Sets whether the embedding should be quantized to bytes via scalar quantization. Defaults to +- * false. +- * +- * <p>Embeddings are implicitly assumed to be unit-norm and therefore any dimension is +- * guaranteed to have a value in {@code [-1.0, 1.0]}. Use the l2_normalize option if this is not +- * the case. +- */ +- public abstract Builder setQuantize(boolean quantize); +- +- /** +- * Sets the index file to search into. +- * +- * <p>Required if the model does not come with an index file inside. Otherwise, it can be ignore +- * by setting to {@code null}. +- */ +- public abstract Builder setIndexFile(@Nullable File indexFile); +- +- /** Sets the maximum number of nearest neighbor results to return. Defaults to {@code 5} */ +- public abstract Builder setMaxResults(int maxResults); +- +- public abstract SearcherOptions build(); +- } ++ private static final boolean DEFAULT_L2_NORMALIZE = false; ++ private static final boolean DEFAULT_QUANTIZE = false; ++ private static final int DEFAULT_MAX_RESULTS = 5; ++ ++ public abstract boolean getL2Normalize(); ++ ++ public abstract boolean getQuantize(); ++ ++ @Nullable ++ public abstract File getIndexFile(); ++ ++ public abstract int getMaxResults(); ++ ++ public static Builder builder() { ++ return new AutoValue_SearcherOptions.Builder() ++ .setL2Normalize(DEFAULT_L2_NORMALIZE) ++ .setQuantize(DEFAULT_QUANTIZE) ++ .setIndexFile(null) ++ .setMaxResults(DEFAULT_MAX_RESULTS); ++ } ++ ++ /** Builder for {@link SearcherOptions}. */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ /** ++ * Sets whether to normalize the embedding feature vector with L2 norm. Defaults to false. ++ * ++ * <p>Use this option only if the model does not already contain a native L2_NORMALIZATION ++ * TFLite Op. In most cases, this is already the case and L2 norm is thus achieved through ++ * TFLite inference. 
++ */ ++ public abstract Builder setL2Normalize(boolean l2Normalize); ++ ++ /** ++ * Sets whether the embedding should be quantized to bytes via scalar quantization. Defaults ++ * to false. ++ * ++ * <p>Embeddings are implicitly assumed to be unit-norm and therefore any dimension is ++ * guaranteed to have a value in {@code [-1.0, 1.0]}. Use the l2_normalize option if this is ++ * not the case. ++ */ ++ public abstract Builder setQuantize(boolean quantize); ++ ++ /** ++ * Sets the index file to search into. ++ * ++ * <p>Required if the model does not come with an index file inside. Otherwise, it can be ++ * ignore by setting to {@code null}. ++ */ ++ public abstract Builder setIndexFile(@Nullable File indexFile); ++ ++ /** Sets the maximum number of nearest neighbor results to return. Defaults to {@code 5} */ ++ public abstract Builder setMaxResults(int maxResults); ++ ++ public abstract SearcherOptions build(); ++ } + } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java -index d0ac3f83b4ed5..ce912c96e29de 100644 +index 55743055ff408..070b945e72b90 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java @@ -17,12 +17,9 @@ package org.tensorflow.lite.task.text.nlclassifier; @@ -21076,12 +23134,12 @@ -import java.nio.MappedByteBuffer; -import java.util.List; + - import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.label.Category; import org.tensorflow.lite.task.core.BaseOptions; -@@ -30,6 +27,12 @@ import org.tensorflow.lite.task.core.BaseTaskApi; - import 
org.tensorflow.lite.task.core.TaskJniUtils; + import org.tensorflow.lite.task.core.BaseTaskApi; +@@ -30,6 +27,12 @@ import org.tensorflow.lite.task.core.TaskJniUtils; import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + import org.tensorflow.lite.task.core.annotations.UsedByReflection; +import java.io.File; +import java.io.IOException; @@ -21092,7 +23150,7 @@ /** * Classifier API for NLClassification tasks with Bert models, categorizes string into different * classes. The API expects a Bert based TFLite model with metadata populated. -@@ -45,209 +48,199 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +@@ -45,209 +48,199 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection; * </ul> */ public class BertNLClassifier extends BaseTaskApi { @@ -21487,7 +23545,7 @@ + private native void deinitJni(long nativeHandle); } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java -index ff573bf415759..b8aa32be94dc5 100644 +index 19dcffca5e697..5c3eb2c9e3768 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java @@ -17,13 +17,11 @@ package org.tensorflow.lite.task.text.nlclassifier; @@ -21504,12 +23562,12 @@ -import java.nio.MappedByteBuffer; -import java.util.List; + - import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.label.Category; import org.tensorflow.lite.task.core.BaseOptions; -@@ -31,6 +29,12 @@ import org.tensorflow.lite.task.core.BaseTaskApi; - import org.tensorflow.lite.task.core.TaskJniUtils; + import 
org.tensorflow.lite.task.core.BaseTaskApi; +@@ -31,6 +29,12 @@ import org.tensorflow.lite.task.core.TaskJniUtils; import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + import org.tensorflow.lite.task.core.annotations.UsedByReflection; +import java.io.File; +import java.io.IOException; @@ -21520,7 +23578,7 @@ /** * Classifier API for natural language classification tasks, categorizes string into different * classes. -@@ -67,294 +71,296 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +@@ -67,294 +71,296 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection; * configurable for different TFLite models. */ public class NLClassifier extends BaseTaskApi { @@ -22531,10 +24589,10 @@ + private native void deinitJni(long nativeHandle); } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java -index 4259a69794059..955da9988ca0a 100644 +index b75a07e10cc7b..50917c035a995 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java -@@ -22,37 +22,37 @@ import org.tensorflow.lite.annotations.UsedByReflection; +@@ -22,37 +22,37 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection; * position information to the context. 
*/ public class QaAnswer { @@ -22630,8 +24688,432 @@ + */ + List<QaAnswer> answer(String context, String question); } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java +index 1a32d10e47114..ea3b1b8c25b34 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java +@@ -18,12 +18,9 @@ package org.tensorflow.lite.task.text.searcher; + import android.content.Context; + import android.content.res.AssetFileDescriptor; + import android.os.ParcelFileDescriptor; ++ + import com.google.auto.value.AutoValue; +-import java.io.File; +-import java.io.IOException; +-import java.nio.ByteBuffer; +-import java.nio.MappedByteBuffer; +-import java.util.List; ++ + import org.tensorflow.lite.task.core.BaseOptions; + import org.tensorflow.lite.task.core.BaseTaskApi; + import org.tensorflow.lite.task.core.TaskJniUtils; +@@ -31,6 +28,12 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + import org.tensorflow.lite.task.processor.NearestNeighbor; + import org.tensorflow.lite.task.processor.SearcherOptions; + ++import java.io.File; ++import java.io.IOException; ++import java.nio.ByteBuffer; ++import java.nio.MappedByteBuffer; ++import java.util.List; ++ + /** + * Performs similarity search on text string. + * +@@ -67,227 +70,193 @@ import org.tensorflow.lite.task.processor.SearcherOptions; + * the single file format (index file packed in the model) is supported. 
+ */ + public final class TextSearcher extends BaseTaskApi { ++ private static final String TEXT_SEARCHER_NATIVE_LIB = "task_text_jni"; ++ private static final int OPTIONAL_FD_LENGTH = -1; ++ private static final int OPTIONAL_FD_OFFSET = -1; + +- private static final String TEXT_SEARCHER_NATIVE_LIB = "task_text_jni"; +- private static final int OPTIONAL_FD_LENGTH = -1; +- private static final int OPTIONAL_FD_OFFSET = -1; ++ /** ++ * Creates an {@link TextSearcher} instance from {@link TextSearcherOptions}. ++ * ++ * @param modelPath path of the search model with metadata in the assets ++ * @throws IOException if an I/O error occurs when loading the tflite model or the index file ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static TextSearcher createFromFileAndOptions(Context context, String modelPath, ++ final TextSearcherOptions options) throws IOException { ++ try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) { ++ return createFromModelFdAndOptions( ++ /*modelDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(), ++ /*modelDescriptorLength=*/assetFileDescriptor.getLength(), ++ /*modelDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options); ++ } ++ } + +- /** +- * Creates an {@link TextSearcher} instance from {@link TextSearcherOptions}. 
+- * +- * @param modelPath path of the search model with metadata in the assets +- * @throws IOException if an I/O error occurs when loading the tflite model or the index file +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static TextSearcher createFromFileAndOptions( +- Context context, String modelPath, final TextSearcherOptions options) throws IOException { +- try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) { +- return createFromModelFdAndOptions( +- /*modelDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(), +- /*modelDescriptorLength=*/ assetFileDescriptor.getLength(), +- /*modelDescriptorOffset=*/ assetFileDescriptor.getStartOffset(), +- options); ++ /** ++ * Creates an {@link TextSearcher} instance. ++ * ++ * @param modelFile the search model {@link File} instance ++ * @throws IOException if an I/O error occurs when loading the tflite model or the index file ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static TextSearcher createFromFileAndOptions( ++ File modelFile, final TextSearcherOptions options) throws IOException { ++ try (ParcelFileDescriptor descriptor = ++ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return createFromModelFdAndOptions( ++ /*modelDescriptor=*/descriptor.getFd(), ++ /*modelDescriptorLength=*/OPTIONAL_FD_LENGTH, ++ /*modelDescriptorOffset=*/OPTIONAL_FD_OFFSET, options); ++ } + } +- } + +- /** +- * Creates an {@link TextSearcher} instance. 
+- * +- * @param modelFile the search model {@link File} instance +- * @throws IOException if an I/O error occurs when loading the tflite model or the index file +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static TextSearcher createFromFileAndOptions( +- File modelFile, final TextSearcherOptions options) throws IOException { +- try (ParcelFileDescriptor descriptor = +- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { +- return createFromModelFdAndOptions( +- /*modelDescriptor=*/ descriptor.getFd(), +- /*modelDescriptorLength=*/ OPTIONAL_FD_LENGTH, +- /*modelDescriptorOffset=*/ OPTIONAL_FD_OFFSET, +- options); ++ /** ++ * Creates an {@link TextSearcher} instance with a model buffer and {@link TextSearcherOptions}. ++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search ++ * model ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ * @throws IOException if an I/O error occurs when loading the index file ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static TextSearcher createFromBufferAndOptions( ++ final ByteBuffer modelBuffer, final TextSearcherOptions options) throws IOException { ++ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { ++ throw new IllegalArgumentException( ++ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ } ++ if (options.getSearcherOptions().getIndexFile() != null) { ++ try (ParcelFileDescriptor indexDescriptor = ++ ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(), ++ ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return 
createFromBufferAndOptionsImpl( ++ modelBuffer, options, indexDescriptor.getFd()); ++ } ++ } else { ++ return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/0); ++ } + } +- } + +- /** +- * Creates an {@link TextSearcher} instance with a model buffer and {@link TextSearcherOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search +- * model +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- * @throws IOException if an I/O error occurs when loading the index file +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static TextSearcher createFromBufferAndOptions( +- final ByteBuffer modelBuffer, final TextSearcherOptions options) throws IOException { +- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { +- throw new IllegalArgumentException( +- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ public static TextSearcher createFromBufferAndOptionsImpl( ++ final ByteBuffer modelBuffer, final TextSearcherOptions options, final int indexFd) { ++ return new TextSearcher(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithByteBuffer(modelBuffer, ++ TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), ++ options.getSearcherOptions().getL2Normalize(), ++ options.getSearcherOptions().getQuantize(), indexFd, ++ options.getSearcherOptions().getMaxResults()); ++ } ++ }, TEXT_SEARCHER_NATIVE_LIB)); + } +- if (options.getSearcherOptions().getIndexFile() != null) { +- try (ParcelFileDescriptor indexDescriptor = +- ParcelFileDescriptor.open( +- options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) { +- return createFromBufferAndOptionsImpl(modelBuffer, 
options, indexDescriptor.getFd()); +- } +- } else { +- return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/ 0); ++ ++ /** ++ * Constructor to initialize the JNI with a pointer from C++. ++ * ++ * @param nativeHandle a pointer referencing memory allocated in C++ ++ */ ++ TextSearcher(long nativeHandle) { ++ super(nativeHandle); + } +- } + +- public static TextSearcher createFromBufferAndOptionsImpl( +- final ByteBuffer modelBuffer, final TextSearcherOptions options, final int indexFd) { +- return new TextSearcher( +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithByteBuffer( +- modelBuffer, +- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), +- options.getSearcherOptions().getL2Normalize(), +- options.getSearcherOptions().getQuantize(), +- indexFd, +- options.getSearcherOptions().getMaxResults()); +- } +- }, +- TEXT_SEARCHER_NATIVE_LIB)); +- } ++ /** Options for setting up an TextSearcher. */ ++ @AutoValue ++ public abstract static class TextSearcherOptions { ++ abstract BaseOptions getBaseOptions(); + +- /** +- * Constructor to initialize the JNI with a pointer from C++. +- * +- * @param nativeHandle a pointer referencing memory allocated in C++ +- */ +- TextSearcher(long nativeHandle) { +- super(nativeHandle); +- } ++ abstract SearcherOptions getSearcherOptions(); + +- /** Options for setting up an TextSearcher. */ +- @AutoValue +- public abstract static class TextSearcherOptions { ++ public static Builder builder() { ++ return new AutoValue_TextSearcher_TextSearcherOptions.Builder() ++ .setBaseOptions(BaseOptions.builder().build()) ++ .setSearcherOptions(SearcherOptions.builder().build()); ++ } + +- abstract BaseOptions getBaseOptions(); ++ /** Builder for {@link TextSearcherOptions}. */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ /** Sets the general options to configure Task APIs, such as accelerators. 
*/ ++ public abstract Builder setBaseOptions(BaseOptions baseOptions); + +- abstract SearcherOptions getSearcherOptions(); ++ /** Sets the options to configure Searcher API. */ ++ public abstract Builder setSearcherOptions(SearcherOptions searcherOptions); + +- public static Builder builder() { +- return new AutoValue_TextSearcher_TextSearcherOptions.Builder() +- .setBaseOptions(BaseOptions.builder().build()) +- .setSearcherOptions(SearcherOptions.builder().build()); ++ public abstract TextSearcherOptions build(); ++ } + } + +- /** Builder for {@link TextSearcherOptions}. */ +- @AutoValue.Builder +- public abstract static class Builder { +- /** Sets the general options to configure Task APIs, such as accelerators. */ +- public abstract Builder setBaseOptions(BaseOptions baseOptions); +- +- /** Sets the options to configure Searcher API. */ +- public abstract Builder setSearcherOptions(SearcherOptions searcherOptions); +- +- public abstract TextSearcherOptions build(); ++ /** ++ * Performs embedding extraction on the provided string input, followed by nearest-neighbor ++ * search in the index. ++ * ++ * @param text input text query to the model ++ */ ++ public List<NearestNeighbor> search(String text) { ++ return searchNative(getNativeHandle(), text); + } +- } +- +- /** +- * Performs embedding extraction on the provided string input, followed by nearest-neighbor search +- * in the index. +- * +- * @param text input text query to the model +- */ +- public List<NearestNeighbor> search(String text) { +- return searchNative(getNativeHandle(), text); +- } + +- private static TextSearcher createFromModelFdAndOptions( +- final int modelDescriptor, +- final long modelDescriptorLength, +- final long modelDescriptorOffset, +- final TextSearcherOptions options) +- throws IOException { +- if (options.getSearcherOptions().getIndexFile() != null) { +- // indexDescriptor must be alive before TextSearcher is initialized completely in the native +- // layer. 
+- try (ParcelFileDescriptor indexDescriptor = +- ParcelFileDescriptor.open( +- options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) { +- return createFromModelFdAndOptionsImpl( +- modelDescriptor, +- modelDescriptorLength, +- modelDescriptorOffset, +- options, +- indexDescriptor.getFd()); +- } +- } else { +- // Index file is not configured. We'll check if the model contains one in the native layer. +- return createFromModelFdAndOptionsImpl( +- modelDescriptor, modelDescriptorLength, modelDescriptorOffset, options, /*indexFd=*/ 0); ++ private static TextSearcher createFromModelFdAndOptions(final int modelDescriptor, ++ final long modelDescriptorLength, final long modelDescriptorOffset, ++ final TextSearcherOptions options) throws IOException { ++ if (options.getSearcherOptions().getIndexFile() != null) { ++ // indexDescriptor must be alive before TextSearcher is initialized completely in the ++ // native layer. ++ try (ParcelFileDescriptor indexDescriptor = ++ ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(), ++ ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength, ++ modelDescriptorOffset, options, indexDescriptor.getFd()); ++ } ++ } else { ++ // Index file is not configured. We'll check if the model contains one in the native ++ // layer. 
++ return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength, ++ modelDescriptorOffset, options, /*indexFd=*/0); ++ } + } +- } + +- private static TextSearcher createFromModelFdAndOptionsImpl( +- final int modelDescriptor, +- final long modelDescriptorLength, +- final long modelDescriptorOffset, +- final TextSearcherOptions options, +- final int indexFd) { +- long nativeHandle = +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithModelFdAndOptions( +- modelDescriptor, +- modelDescriptorLength, +- modelDescriptorOffset, +- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), +- options.getSearcherOptions().getL2Normalize(), +- options.getSearcherOptions().getQuantize(), +- indexFd, +- options.getSearcherOptions().getMaxResults()); +- } +- }, +- TEXT_SEARCHER_NATIVE_LIB); +- return new TextSearcher(nativeHandle); +- } ++ private static TextSearcher createFromModelFdAndOptionsImpl(final int modelDescriptor, ++ final long modelDescriptorLength, final long modelDescriptorOffset, ++ final TextSearcherOptions options, final int indexFd) { ++ long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithModelFdAndOptions(modelDescriptor, modelDescriptorLength, ++ modelDescriptorOffset, ++ TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), ++ options.getSearcherOptions().getL2Normalize(), ++ options.getSearcherOptions().getQuantize(), indexFd, ++ options.getSearcherOptions().getMaxResults()); ++ } ++ }, TEXT_SEARCHER_NATIVE_LIB); ++ return new TextSearcher(nativeHandle); ++ } + +- private static native long initJniWithModelFdAndOptions( +- int modelDescriptor, +- long modelDescriptorLength, +- long modelDescriptorOffset, +- long baseOptionsHandle, +- boolean l2Normalize, +- boolean quantize, +- int indexDescriptor, +- int maxResults); ++ private 
static native long initJniWithModelFdAndOptions(int modelDescriptor, ++ long modelDescriptorLength, long modelDescriptorOffset, long baseOptionsHandle, ++ boolean l2Normalize, boolean quantize, int indexDescriptor, int maxResults); + +- private static native long initJniWithByteBuffer( +- ByteBuffer modelBuffer, +- long baseOptionsHandle, +- boolean l2Normalize, +- boolean quantize, +- int indexFileDescriptor, +- int maxResults); ++ private static native long initJniWithByteBuffer(ByteBuffer modelBuffer, long baseOptionsHandle, ++ boolean l2Normalize, boolean quantize, int indexFileDescriptor, int maxResults); + +- /** The native method to search an input text string. */ +- private static native List<NearestNeighbor> searchNative(long nativeHandle, String text); ++ /** The native method to search an input text string. */ ++ private static native List<NearestNeighbor> searchNative(long nativeHandle, String text); + +- @Override +- protected void deinit(long nativeHandle) { +- deinitJni(nativeHandle); +- } ++ @Override ++ protected void deinit(long nativeHandle) { ++ deinitJni(nativeHandle); ++ } + +- /** +- * Native implementation to release memory pointed by the pointer. +- * +- * @param nativeHandle pointer to memory allocated +- */ +- private native void deinitJni(long nativeHandle); ++ /** ++ * Native implementation to release memory pointed by the pointer. 
++ * ++ * @param nativeHandle pointer to memory allocated ++ */ ++ private native void deinitJni(long nativeHandle); + } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java -index d33f0fbbdd497..0d35443a7de5d 100644 +index 88aeecc8d62ca..e59a2e89e86f4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java @@ -16,11 +16,13 @@ limitations under the License. @@ -22639,18 +25121,18 @@ import com.google.auto.value.AutoValue; + -+import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.label.Category; ++import org.tensorflow.lite.task.core.annotations.UsedByReflection; + import java.util.ArrayList; import java.util.Collections; import java.util.List; --import org.tensorflow.lite.annotations.UsedByReflection; -import org.tensorflow.lite.support.label.Category; +-import org.tensorflow.lite.task.core.annotations.UsedByReflection; /** * The classification results of one head in a multihead (a.k.a. 
multi-output) {@link -@@ -31,16 +33,15 @@ import org.tensorflow.lite.support.label.Category; +@@ -31,16 +33,15 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection; @AutoValue @UsedByReflection("image_classifier_jni.cc") public abstract class Classifications { @@ -22677,7 +25159,7 @@ + public abstract int getHeadIndex(); } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java -index 2bf3fa8a465b4..48038f6a1c04e 100644 +index 90628928198d5..5b5be73bcca1e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java @@ -18,14 +18,9 @@ package org.tensorflow.lite.task.vision.classifier; @@ -22694,9 +25176,9 @@ -import java.util.Collections; -import java.util.List; + - import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.image.MlImageAdapter; import org.tensorflow.lite.support.image.TensorImage; + import org.tensorflow.lite.task.core.BaseOptions; @@ -37,6 +32,14 @@ import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider; @@ -23987,7 +26469,7 @@ - } } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java -index 007e032d8b331..7106fe8a08b35 100644 +index 859e41fc038be..096af521c6b00 100644 --- 
a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java @@ -16,27 +16,29 @@ limitations under the License. @@ -23997,14 +26479,14 @@ + import com.google.auto.value.AutoValue; + -+import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.label.Category; ++import org.tensorflow.lite.task.core.annotations.UsedByReflection; + import java.util.ArrayList; import java.util.Collections; import java.util.List; --import org.tensorflow.lite.annotations.UsedByReflection; -import org.tensorflow.lite.support.label.Category; +-import org.tensorflow.lite.task.core.annotations.UsedByReflection; /** Represents one detected object in the results of a {@link ObjectDetector}. */ @AutoValue @@ -24033,7 +26515,7 @@ + public abstract List<Category> getCategories(); } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java -index e2046d15a7351..c0585b8eda6aa 100644 +index 4aff7bfab8ca5..d1fb421fc0bbf 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java @@ -17,14 +17,9 @@ package org.tensorflow.lite.task.vision.detector; @@ -24050,10 +26532,10 @@ -import java.util.Collections; -import java.util.List; + - import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.image.MlImageAdapter; import org.tensorflow.lite.support.image.TensorImage; -@@ -35,6 +30,14 @@ import 
org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; + import org.tensorflow.lite.task.core.BaseOptions; +@@ -35,6 +30,14 @@ import org.tensorflow.lite.task.core.annotations.UsedByReflection; import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; @@ -24955,8 +27437,649 @@ + */ + private native void deinitJni(long nativeHandle); } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java +index 7a02ad8a037a2..d3d1e6a4f4878 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java +@@ -19,13 +19,10 @@ import android.content.Context; + import android.content.res.AssetFileDescriptor; + import android.graphics.Rect; + import android.os.ParcelFileDescriptor; ++ + import com.google.android.odml.image.MlImage; + import com.google.auto.value.AutoValue; +-import java.io.File; +-import java.io.IOException; +-import java.nio.ByteBuffer; +-import java.nio.MappedByteBuffer; +-import java.util.List; ++ + import org.tensorflow.lite.support.image.MlImageAdapter; + import org.tensorflow.lite.support.image.TensorImage; + import org.tensorflow.lite.task.core.BaseOptions; +@@ -37,6 +34,12 @@ import org.tensorflow.lite.task.processor.SearcherOptions; + import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; + import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider; + ++import java.io.File; ++import java.io.IOException; ++import java.nio.ByteBuffer; ++import java.nio.MappedByteBuffer; ++import java.util.List; ++ + /** + * Performs 
similarity search on images. + * +@@ -66,330 +69,292 @@ import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider; + * the single file format (index file packed in the model) is supported. + */ + public final class ImageSearcher extends BaseVisionTaskApi { ++ private static final String IMAGE_SEARCHER_NATIVE_LIB = "task_vision_jni"; ++ private static final int OPTIONAL_FD_LENGTH = -1; ++ private static final int OPTIONAL_FD_OFFSET = -1; + +- private static final String IMAGE_SEARCHER_NATIVE_LIB = "task_vision_jni"; +- private static final int OPTIONAL_FD_LENGTH = -1; +- private static final int OPTIONAL_FD_OFFSET = -1; +- +- /** +- * Creates an {@link ImageSearcher} instance from {@link ImageSearcherOptions}. +- * +- * @param modelPath path of the search model with metadata in the assets +- * @throws IOException if an I/O error occurs when loading the tflite model or the index file +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageSearcher createFromFileAndOptions( +- Context context, String modelPath, final ImageSearcherOptions options) throws IOException { +- try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) { +- return createFromModelFdAndOptions( +- /*modelDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(), +- /*modelDescriptorLength=*/ assetFileDescriptor.getLength(), +- /*modelDescriptorOffset=*/ assetFileDescriptor.getStartOffset(), +- options); ++ /** ++ * Creates an {@link ImageSearcher} instance from {@link ImageSearcherOptions}. 
++ * ++ * @param modelPath path of the search model with metadata in the assets ++ * @throws IOException if an I/O error occurs when loading the tflite model or the index file ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageSearcher createFromFileAndOptions(Context context, String modelPath, ++ final ImageSearcherOptions options) throws IOException { ++ try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) { ++ return createFromModelFdAndOptions( ++ /*modelDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(), ++ /*modelDescriptorLength=*/assetFileDescriptor.getLength(), ++ /*modelDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options); ++ } + } +- } +- +- /** +- * Creates an {@link ImageSearcher} instance. +- * +- * @param modelFile the search model {@link File} instance +- * @throws IOException if an I/O error occurs when loading the tflite model or the index file +- * @throws IllegalArgumentException if an argument is invalid +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageSearcher createFromFileAndOptions( +- File modelFile, final ImageSearcherOptions options) throws IOException { +- try (ParcelFileDescriptor descriptor = +- ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { +- return createFromModelFdAndOptions( +- /*modelDescriptor=*/ descriptor.getFd(), +- /*modelDescriptorLength=*/ OPTIONAL_FD_LENGTH, +- /*modelDescriptorOffset=*/ OPTIONAL_FD_OFFSET, +- options); ++ ++ /** ++ * Creates an {@link ImageSearcher} instance. 
++ * ++ * @param modelFile the search model {@link File} instance ++ * @throws IOException if an I/O error occurs when loading the tflite model or the index file ++ * @throws IllegalArgumentException if an argument is invalid ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageSearcher createFromFileAndOptions( ++ File modelFile, final ImageSearcherOptions options) throws IOException { ++ try (ParcelFileDescriptor descriptor = ++ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return createFromModelFdAndOptions( ++ /*modelDescriptor=*/descriptor.getFd(), ++ /*modelDescriptorLength=*/OPTIONAL_FD_LENGTH, ++ /*modelDescriptorOffset=*/OPTIONAL_FD_OFFSET, options); ++ } + } +- } +- +- /** +- * Creates an {@link ImageSearcher} instance with a model buffer and {@link ImageSearcherOptions}. +- * +- * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search +- * model +- * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a +- * {@link MappedByteBuffer} +- * @throws IOException if an I/O error occurs when loading the index file +- * @throws IllegalStateException if there is an internal error +- * @throws RuntimeException if there is an otherwise unspecified error +- */ +- public static ImageSearcher createFromBufferAndOptions( +- final ByteBuffer modelBuffer, final ImageSearcherOptions options) throws IOException { +- if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { +- throw new IllegalArgumentException( +- "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ ++ /** ++ * Creates an {@link ImageSearcher} instance with a model buffer and {@link ++ * ImageSearcherOptions}. 
++ * ++ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search ++ * model ++ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a ++ * {@link MappedByteBuffer} ++ * @throws IOException if an I/O error occurs when loading the index file ++ * @throws IllegalStateException if there is an internal error ++ * @throws RuntimeException if there is an otherwise unspecified error ++ */ ++ public static ImageSearcher createFromBufferAndOptions( ++ final ByteBuffer modelBuffer, final ImageSearcherOptions options) throws IOException { ++ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { ++ throw new IllegalArgumentException( ++ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); ++ } ++ if (options.getSearcherOptions().getIndexFile() != null) { ++ try (ParcelFileDescriptor indexDescriptor = ++ ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(), ++ ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return createFromBufferAndOptionsImpl( ++ modelBuffer, options, indexDescriptor.getFd()); ++ } ++ } else { ++ return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/0); ++ } + } +- if (options.getSearcherOptions().getIndexFile() != null) { +- try (ParcelFileDescriptor indexDescriptor = +- ParcelFileDescriptor.open( +- options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) { +- return createFromBufferAndOptionsImpl(modelBuffer, options, indexDescriptor.getFd()); +- } +- } else { +- return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/ 0); ++ ++ public static ImageSearcher createFromBufferAndOptionsImpl( ++ final ByteBuffer modelBuffer, final ImageSearcherOptions options, final int indexFd) { ++ return new ImageSearcher(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithByteBuffer(modelBuffer, ++ 
TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), ++ options.getSearcherOptions().getL2Normalize(), ++ options.getSearcherOptions().getQuantize(), indexFd, ++ options.getSearcherOptions().getMaxResults()); ++ } ++ }, IMAGE_SEARCHER_NATIVE_LIB)); + } +- } +- +- public static ImageSearcher createFromBufferAndOptionsImpl( +- final ByteBuffer modelBuffer, final ImageSearcherOptions options, final int indexFd) { +- return new ImageSearcher( +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithByteBuffer( +- modelBuffer, +- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), +- options.getSearcherOptions().getL2Normalize(), +- options.getSearcherOptions().getQuantize(), +- indexFd, +- options.getSearcherOptions().getMaxResults()); +- } +- }, +- IMAGE_SEARCHER_NATIVE_LIB)); +- } +- +- /** +- * Constructor to initialize the JNI with a pointer from C++. +- * +- * @param nativeHandle a pointer referencing memory allocated in C++ +- */ +- ImageSearcher(long nativeHandle) { +- super(nativeHandle); +- } +- +- /** Options for setting up an ImageSearcher. */ +- @AutoValue +- public abstract static class ImageSearcherOptions { +- +- abstract BaseOptions getBaseOptions(); +- +- abstract SearcherOptions getSearcherOptions(); +- +- public static Builder builder() { +- return new AutoValue_ImageSearcher_ImageSearcherOptions.Builder() +- .setBaseOptions(BaseOptions.builder().build()) +- .setSearcherOptions(SearcherOptions.builder().build()); ++ ++ /** ++ * Constructor to initialize the JNI with a pointer from C++. ++ * ++ * @param nativeHandle a pointer referencing memory allocated in C++ ++ */ ++ ImageSearcher(long nativeHandle) { ++ super(nativeHandle); + } + +- /** Builder for {@link ImageSearcherOptions}. */ +- @AutoValue.Builder +- public abstract static class Builder { +- /** Sets the general options to configure Task APIs, such as accelerators. 
*/ +- public abstract Builder setBaseOptions(BaseOptions baseOptions); ++ /** Options for setting up an ImageSearcher. */ ++ @AutoValue ++ public abstract static class ImageSearcherOptions { ++ abstract BaseOptions getBaseOptions(); ++ ++ abstract SearcherOptions getSearcherOptions(); ++ ++ public static Builder builder() { ++ return new AutoValue_ImageSearcher_ImageSearcherOptions.Builder() ++ .setBaseOptions(BaseOptions.builder().build()) ++ .setSearcherOptions(SearcherOptions.builder().build()); ++ } ++ ++ /** Builder for {@link ImageSearcherOptions}. */ ++ @AutoValue.Builder ++ public abstract static class Builder { ++ /** Sets the general options to configure Task APIs, such as accelerators. */ ++ public abstract Builder setBaseOptions(BaseOptions baseOptions); + +- /** Sets the options to configure Searcher API. */ +- public abstract Builder setSearcherOptions(SearcherOptions searcherOptions); ++ /** Sets the options to configure Searcher API. */ ++ public abstract Builder setSearcherOptions(SearcherOptions searcherOptions); + +- public abstract ImageSearcherOptions build(); ++ public abstract ImageSearcherOptions build(); ++ } + } +- } +- +- /** +- * Performs embedding extraction on the provided {@link TensorImage}, followed by nearest-neighbor +- * search in the index. 
+- * +- * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types: +- * +- * <ul> +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} +- * </ul> +- * +- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image +- * @throws IllegalArgumentException if the color space type of image is unsupported +- */ +- public List<NearestNeighbor> search(TensorImage image) { +- return search(image, ImageProcessingOptions.builder().build()); +- } +- +- /** +- * Performs embedding extraction on the provided {@link TensorImage} with {@link +- * ImageProcessingOptions}, followed by nearest-neighbor search in the index. +- * +- * <p>{@link ImageSearcher} supports the following options: +- * +- * <ul> +- * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It +- * defaults to the entire image. +- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It +- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. 
+- * </ul> +- * +- * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types: +- * +- * <ul> +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} +- * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} +- * </ul> +- * +- * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image +- * @throws IllegalArgumentException if the color space type of image is unsupported +- */ +- public List<NearestNeighbor> search(TensorImage image, ImageProcessingOptions options) { +- return run( +- new InferenceProvider<List<NearestNeighbor>>() { +- @Override +- public List<NearestNeighbor> run( +- long frameBufferHandle, int width, int height, ImageProcessingOptions options) { +- return search(frameBufferHandle, width, height, options); +- } +- }, +- image, +- options); +- } +- +- /** +- * Performs embedding extraction on the provided {@code MlImage}, followed by nearest-neighbor +- * search in the index. +- * +- * @param image an {@code MlImage} object that represents an image +- * @throws IllegalArgumentException if the storage type or format of the image is unsupported +- */ +- public List<NearestNeighbor> search(MlImage image) { +- return search(image, ImageProcessingOptions.builder().build()); +- } +- +- /** +- * Performs embedding extraction on the provided {@code MlImage} with {@link +- * ImageProcessingOptions}, followed by nearest-neighbor search in the index. +- * +- * <p>{@link ImageSearcher} supports the following options: +- * +- * <ul> +- * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It +- * defaults to the entire image. +- * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). 
It +- * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link +- * MlImage#getRotation()} is not effective. +- * </ul> +- * +- * @param image a {@code MlImage} object that represents an image +- * @param options configures options including ROI and rotation +- * @throws IllegalArgumentException if the storage type or format of the image is unsupported +- */ +- public List<NearestNeighbor> search(MlImage image, ImageProcessingOptions options) { +- image.getInternal().acquire(); +- TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); +- List<NearestNeighbor> result = search(tensorImage, options); +- image.close(); +- return result; +- } +- +- private List<NearestNeighbor> search( +- long frameBufferHandle, int width, int height, ImageProcessingOptions options) { +- checkNotClosed(); +- Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi(); +- return searchNative( +- getNativeHandle(), +- frameBufferHandle, +- new int[] {roi.left, roi.top, roi.width(), roi.height()}); +- } +- +- private static ImageSearcher createFromModelFdAndOptions( +- final int modelDescriptor, +- final long modelDescriptorLength, +- final long modelDescriptorOffset, +- final ImageSearcherOptions options) +- throws IOException { +- if (options.getSearcherOptions().getIndexFile() != null) { +- // indexDescriptor must be alive before ImageSearcher is initialized completely in the native +- // layer. +- try (ParcelFileDescriptor indexDescriptor = +- ParcelFileDescriptor.open( +- options.getSearcherOptions().getIndexFile(), ParcelFileDescriptor.MODE_READ_ONLY)) { +- return createFromModelFdAndOptionsImpl( +- modelDescriptor, +- modelDescriptorLength, +- modelDescriptorOffset, +- options, +- indexDescriptor.getFd()); +- } +- } else { +- // Index file is not configured. We'll check if the model contains one in the native layer. 
+- return createFromModelFdAndOptionsImpl( +- modelDescriptor, modelDescriptorLength, modelDescriptorOffset, options, /*indexFd=*/ 0); ++ ++ /** ++ * Performs embedding extraction on the provided {@link TensorImage}, followed by ++ * nearest-neighbor search in the index. ++ * ++ * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types: ++ * ++ * <ul> ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} ++ * </ul> ++ * ++ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image ++ * @throws IllegalArgumentException if the color space type of image is unsupported ++ */ ++ public List<NearestNeighbor> search(TensorImage image) { ++ return search(image, ImageProcessingOptions.builder().build()); ++ } ++ ++ /** ++ * Performs embedding extraction on the provided {@link TensorImage} with {@link ++ * ImageProcessingOptions}, followed by nearest-neighbor search in the index. ++ * ++ * <p>{@link ImageSearcher} supports the following options: ++ * ++ * <ul> ++ * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It ++ * defaults to the entire image. ++ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It ++ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. 
++ * </ul> ++ * ++ * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types: ++ * ++ * <ul> ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} ++ * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} ++ * </ul> ++ * ++ * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image ++ * @throws IllegalArgumentException if the color space type of image is unsupported ++ */ ++ public List<NearestNeighbor> search(TensorImage image, ImageProcessingOptions options) { ++ return run(new InferenceProvider<List<NearestNeighbor>>() { ++ @Override ++ public List<NearestNeighbor> run( ++ long frameBufferHandle, int width, int height, ImageProcessingOptions options) { ++ return search(frameBufferHandle, width, height, options); ++ } ++ }, image, options); ++ } ++ ++ /** ++ * Performs embedding extraction on the provided {@code MlImage}, followed by nearest-neighbor ++ * search in the index. 
++ * ++ * @param image an {@code MlImage} object that represents an image ++ * @throws IllegalArgumentException if the storage type or format of the image is unsupported ++ */ ++ public List<NearestNeighbor> search(MlImage image) { ++ return search(image, ImageProcessingOptions.builder().build()); + } +- } +- +- private static ImageSearcher createFromModelFdAndOptionsImpl( +- final int modelDescriptor, +- final long modelDescriptorLength, +- final long modelDescriptorOffset, +- final ImageSearcherOptions options, +- final int indexFd) { +- long nativeHandle = +- TaskJniUtils.createHandleFromLibrary( +- new EmptyHandleProvider() { +- @Override +- public long createHandle() { +- return initJniWithModelFdAndOptions( +- modelDescriptor, +- modelDescriptorLength, +- modelDescriptorOffset, +- TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), +- options.getSearcherOptions().getL2Normalize(), +- options.getSearcherOptions().getQuantize(), +- indexFd, +- options.getSearcherOptions().getMaxResults()); +- } +- }, +- IMAGE_SEARCHER_NATIVE_LIB); +- return new ImageSearcher(nativeHandle); +- } +- +- private static native long initJniWithModelFdAndOptions( +- int modelDescriptor, +- long modelDescriptorLength, +- long modelDescriptorOffset, +- long baseOptionsHandle, +- boolean l2Normalize, +- boolean quantize, +- int indexDescriptor, +- int maxResults); +- +- private static native long initJniWithByteBuffer( +- ByteBuffer modelBuffer, +- long baseOptionsHandle, +- boolean l2Normalize, +- boolean quantize, +- int indexFileDescriptor, +- int maxResults); +- +- /** +- * The native method to search an image based on the ROI specified. 
+- * +- * @param roi the ROI of the input image, an array representing the bounding box as {left, top, +- * width, height} +- */ +- private static native List<NearestNeighbor> searchNative( +- long nativeHandle, long frameBufferHandle, int[] roi); +- +- @Override +- protected void deinit(long nativeHandle) { +- deinitJni(nativeHandle); +- } +- +- /** +- * Native implementation to release memory pointed by the pointer. +- * +- * @param nativeHandle pointer to memory allocated +- */ +- private native void deinitJni(long nativeHandle); ++ ++ /** ++ * Performs embedding extraction on the provided {@code MlImage} with {@link ++ * ImageProcessingOptions}, followed by nearest-neighbor search in the index. ++ * ++ * <p>{@link ImageSearcher} supports the following options: ++ * ++ * <ul> ++ * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It ++ * defaults to the entire image. ++ * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It ++ * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. {@link ++ * MlImage#getRotation()} is not effective. ++ * </ul> ++ * ++ * @param image a {@code MlImage} object that represents an image ++ * @param options configures options including ROI and rotation ++ * @throws IllegalArgumentException if the storage type or format of the image is unsupported ++ */ ++ public List<NearestNeighbor> search(MlImage image, ImageProcessingOptions options) { ++ image.getInternal().acquire(); ++ TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); ++ List<NearestNeighbor> result = search(tensorImage, options); ++ image.close(); ++ return result; ++ } ++ ++ private List<NearestNeighbor> search( ++ long frameBufferHandle, int width, int height, ImageProcessingOptions options) { ++ checkNotClosed(); ++ Rect roi = options.getRoi().isEmpty() ? 
new Rect(0, 0, width, height) : options.getRoi(); ++ return searchNative(getNativeHandle(), frameBufferHandle, ++ new int[] {roi.left, roi.top, roi.width(), roi.height()}); ++ } ++ ++ private static ImageSearcher createFromModelFdAndOptions(final int modelDescriptor, ++ final long modelDescriptorLength, final long modelDescriptorOffset, ++ final ImageSearcherOptions options) throws IOException { ++ if (options.getSearcherOptions().getIndexFile() != null) { ++ // indexDescriptor must be alive before ImageSearcher is initialized completely in the ++ // native layer. ++ try (ParcelFileDescriptor indexDescriptor = ++ ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(), ++ ParcelFileDescriptor.MODE_READ_ONLY)) { ++ return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength, ++ modelDescriptorOffset, options, indexDescriptor.getFd()); ++ } ++ } else { ++ // Index file is not configured. We'll check if the model contains one in the native ++ // layer. ++ return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength, ++ modelDescriptorOffset, options, /*indexFd=*/0); ++ } ++ } ++ ++ private static ImageSearcher createFromModelFdAndOptionsImpl(final int modelDescriptor, ++ final long modelDescriptorLength, final long modelDescriptorOffset, ++ final ImageSearcherOptions options, final int indexFd) { ++ long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { ++ @Override ++ public long createHandle() { ++ return initJniWithModelFdAndOptions(modelDescriptor, modelDescriptorLength, ++ modelDescriptorOffset, ++ TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), ++ options.getSearcherOptions().getL2Normalize(), ++ options.getSearcherOptions().getQuantize(), indexFd, ++ options.getSearcherOptions().getMaxResults()); ++ } ++ }, IMAGE_SEARCHER_NATIVE_LIB); ++ return new ImageSearcher(nativeHandle); ++ } ++ ++ private static native long initJniWithModelFdAndOptions(int modelDescriptor, 
++ long modelDescriptorLength, long modelDescriptorOffset, long baseOptionsHandle, ++ boolean l2Normalize, boolean quantize, int indexDescriptor, int maxResults); ++ ++ private static native long initJniWithByteBuffer(ByteBuffer modelBuffer, long baseOptionsHandle, ++ boolean l2Normalize, boolean quantize, int indexFileDescriptor, int maxResults); ++ ++ /** ++ * The native method to search an image based on the ROI specified. ++ * ++ * @param roi the ROI of the input image, an array representing the bounding box as {left, top, ++ * width, height} ++ */ ++ private static native List<NearestNeighbor> searchNative( ++ long nativeHandle, long frameBufferHandle, int[] roi); ++ ++ @Override ++ protected void deinit(long nativeHandle) { ++ deinitJni(nativeHandle); ++ } ++ ++ /** ++ * Native implementation to release memory pointed by the pointer. ++ * ++ * @param nativeHandle pointer to memory allocated ++ */ ++ private native void deinitJni(long nativeHandle); + } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java -index 0defaa9f16b96..991fedeeae9c2 100644 +index a92e70ebc09b4..7a7a5b323f43b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java @@ -17,72 +17,74 @@ package org.tensorflow.lite.task.vision.segmenter; @@ -24968,7 +28091,7 @@ + import com.google.auto.value.AutoValue; + - import org.tensorflow.lite.annotations.UsedByReflection; + import org.tensorflow.lite.task.core.annotations.UsedByReflection; /** Represents a label associated with a color for display purposes. 
*/ @AutoValue @@ -35770,7 +38893,7 @@ return RunClassifier(env, native_handle, text); } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc -index c392c9a5a972f..401e6fbda3d9b 100644 +index 1ff0d9fc46161..b77746a2eee68 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc @@ -52,14 +52,19 @@ BertQuestionAnswererOptions ConvertToProtoOptions(jlong base_options_handle) { @@ -35830,6 +38953,81 @@ jstring question) { auto* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc +index 8573b0f444626..c207755d3393f 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc +@@ -48,7 +48,8 @@ using ::tflite::task::text::TextSearcherOptions; + + // Creates an TextSearcherOptions proto based on the Java class. 
+ TextSearcherOptions ConvertToProtoOptions(jlong base_options_handle, +- bool l2_normalize, bool quantize, ++ bool l2_normalize, ++ bool quantize, + int index_descriptor, + int max_results) { + TextSearcherOptions proto_options; +@@ -120,7 +121,9 @@ jobject ConvertToSearchResults(JNIEnv* env, const SearchResult& results) { + + extern "C" JNIEXPORT void JNICALL + Java_org_tensorflow_lite_task_text_searcher_TextSearcher_deinitJni( +- JNIEnv* env, jobject thiz, jlong native_handle) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong native_handle) { + delete reinterpret_cast<TextSearcher*>(native_handle); + } + +@@ -129,10 +132,16 @@ Java_org_tensorflow_lite_task_text_searcher_TextSearcher_deinitJni( + // values will be ignored. + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithModelFdAndOptions( +- JNIEnv* env, jclass thiz, jint model_descriptor, +- jlong model_descriptor_length, jlong model_descriptor_offset, +- jlong base_options_handle, bool l2_normalize, bool quantize, +- jint index_descriptor, int max_results) { ++ JNIEnv* env, ++ jclass thiz, ++ jint model_descriptor, ++ jlong model_descriptor_length, ++ jlong model_descriptor_offset, ++ jlong base_options_handle, ++ bool l2_normalize, ++ bool quantize, ++ jint index_descriptor, ++ int max_results) { + TextSearcherOptions proto_options = + ConvertToProtoOptions(base_options_handle, l2_normalize, quantize, + index_descriptor, max_results); +@@ -152,8 +161,14 @@ Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithModelFdAndOp + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithByteBuffer( +- JNIEnv* env, jclass thiz, jobject model_buffer, jlong base_options_handle, +- bool l2_normalize, bool quantize, jlong index_descriptor, int max_results) { ++ JNIEnv* env, ++ jclass thiz, ++ jobject model_buffer, ++ jlong base_options_handle, ++ bool l2_normalize, ++ bool quantize, ++ jlong 
index_descriptor, ++ int max_results) { + TextSearcherOptions proto_options = + ConvertToProtoOptions(base_options_handle, l2_normalize, quantize, + index_descriptor, max_results); +@@ -166,7 +181,10 @@ Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithByteBuffer( + + extern "C" JNIEXPORT jobject JNICALL + Java_org_tensorflow_lite_task_text_searcher_TextSearcher_searchNative( +- JNIEnv* env, jclass thiz, jlong native_handle, jstring text) { ++ JNIEnv* env, ++ jclass thiz, ++ jlong native_handle, ++ jstring text) { + auto* searcher = reinterpret_cast<TextSearcher*>(native_handle); + auto results_or = searcher->Search(JStringToString(env, text)); + diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc index 18e2ee1a7d4ab..2a713cf8b63cf 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc @@ -36137,6 +39335,81 @@ } // namespace vision } // namespace task } // namespace tflite +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc +index e57f12a16aab3..84cad5db43ea2 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc +@@ -52,7 +52,8 @@ using ::tflite::task::vision::ImageSearcherOptions; + + // Creates an ImageSearcherOptions proto based on the Java class. 
+ ImageSearcherOptions ConvertToProtoOptions(jlong base_options_handle, +- bool l2_normalize, bool quantize, ++ bool l2_normalize, ++ bool quantize, + int index_descriptor, + int max_results) { + ImageSearcherOptions proto_options; +@@ -124,7 +125,9 @@ jobject ConvertToSearchResults(JNIEnv* env, const SearchResult& results) { + + extern "C" JNIEXPORT void JNICALL + Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_deinitJni( +- JNIEnv* env, jobject thiz, jlong native_handle) { ++ JNIEnv* env, ++ jobject thiz, ++ jlong native_handle) { + delete reinterpret_cast<ImageSearcher*>(native_handle); + } + +@@ -133,10 +136,16 @@ Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_deinitJni( + // values will be ignored. + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithModelFdAndOptions( +- JNIEnv* env, jclass thiz, jint model_descriptor, +- jlong model_descriptor_length, jlong model_descriptor_offset, +- jlong base_options_handle, bool l2_normalize, bool quantize, +- jint index_descriptor, int max_results) { ++ JNIEnv* env, ++ jclass thiz, ++ jint model_descriptor, ++ jlong model_descriptor_length, ++ jlong model_descriptor_offset, ++ jlong base_options_handle, ++ bool l2_normalize, ++ bool quantize, ++ jint index_descriptor, ++ int max_results) { + ImageSearcherOptions proto_options = + ConvertToProtoOptions(base_options_handle, l2_normalize, quantize, + index_descriptor, max_results); +@@ -156,8 +165,14 @@ Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithModelFdAn + + extern "C" JNIEXPORT jlong JNICALL + Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithByteBuffer( +- JNIEnv* env, jclass thiz, jobject model_buffer, jlong base_options_handle, +- bool l2_normalize, bool quantize, jlong index_descriptor, int max_results) { ++ JNIEnv* env, ++ jclass thiz, ++ jobject model_buffer, ++ jlong base_options_handle, ++ bool l2_normalize, ++ bool quantize, ++ jlong 
index_descriptor, ++ int max_results) { + ImageSearcherOptions proto_options = + ConvertToProtoOptions(base_options_handle, l2_normalize, quantize, + index_descriptor, max_results); +@@ -170,7 +185,10 @@ Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithByteBuffe + + extern "C" JNIEXPORT jobject JNICALL + Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_searchNative( +- JNIEnv* env, jclass thiz, jlong native_handle, jlong frame_buffer_handle, ++ JNIEnv* env, ++ jclass thiz, ++ jlong native_handle, ++ jlong frame_buffer_handle, + jintArray jroi) { + auto* searcher = reinterpret_cast<ImageSearcher*>(native_handle); + // frame_buffer will be deleted after inference is done in diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc index 40fa4472d37e1..8d8c8eec34295 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc @@ -36202,35 +39475,28 @@ // frame_buffer will be deleted after inference is done in // base_vision_api_jni.cc. diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc -index 7a9843d61d63c..3aae0aa0ec5c7 100644 +index 65a01c0b9d33a..2a72338741626 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc -@@ -17,11 +17,11 @@ limitations under the License. +@@ -17,13 +17,13 @@ limitations under the License. 
- #include <functional> + #include <string> -#include "absl/memory/memory.h" // from @com_google_absl -#include "absl/status/status.h" // from @com_google_absl -+#include "absl/memory/memory.h" // from @com_google_absl -+#include "absl/status/status.h" // from @com_google_absl - #include "absl/strings/str_format.h" // from @com_google_absl - #include "flatbuffers/flatbuffers.h" // from @flatbuffers --#include "lib/zip.h" // from @org_libzip -+#include "lib/zip.h" // from @org_libzip +-#include "absl/strings/str_format.h" // from @com_google_absl ++#include "absl/memory/memory.h" // from @com_google_absl ++#include "absl/status/status.h" // from @com_google_absl ++#include "absl/strings/str_format.h" // from @com_google_absl + #include "absl/strings/string_view.h" // from @com_google_absl +-#include "flatbuffers/flatbuffers.h" // from @flatbuffers + #include "contrib/minizip/ioapi.h" + #include "contrib/minizip/unzip.h" ++#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/status_macros.h" -@@ -48,7 +48,8 @@ class SimpleCleanUp { - : callback_(std::move(callback)) {} - - ~SimpleCleanUp() { -- if (callback_ != nullptr) callback_(); -+ if (callback_ != nullptr) -+ callback_(); - } - - // Use `std::move(simple_cleanup).Cancel()` to prevent the callback from ever -@@ -63,7 +64,8 @@ class SimpleCleanUp { +@@ -46,7 +46,8 @@ using ::tflite::support::TfLiteSupportStatus; // Util to get item from src_vector specified by index. 
template <typename T> const T* GetItemFromVector( @@ -36240,7 +39506,7 @@ if (src_vector == nullptr || index < 0 || index >= src_vector->size()) { return nullptr; } -@@ -111,7 +113,8 @@ ModelMetadataExtractor::FindFirstProcessUnit( +@@ -158,7 +159,8 @@ ModelMetadataExtractor::FindFirstProcessUnit( /* static */ std::string ModelMetadataExtractor::FindFirstAssociatedFileName( const tflite::TensorMetadata& tensor_metadata, @@ -36250,7 +39516,7 @@ if (tensor_metadata.associated_files() == nullptr) { return std::string(); } -@@ -128,7 +131,8 @@ std::string ModelMetadataExtractor::FindFirstAssociatedFileName( +@@ -175,7 +177,8 @@ std::string ModelMetadataExtractor::FindFirstAssociatedFileName( } absl::Status ModelMetadataExtractor::InitFromModelBuffer( @@ -36260,18 +39526,18 @@ // Rely on the simplest, base flatbuffers verifier. Here is not the place to // e.g. use an OpResolver: we just want to make sure the buffer is valid to // access the metadata. -@@ -187,7 +191,8 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer( +@@ -234,7 +237,8 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer( } absl::Status ModelMetadataExtractor::ExtractAssociatedFiles( - const char* buffer_data, size_t buffer_size) { + const char* buffer_data, + size_t buffer_size) { - // Setup libzip error reporting. - zip_error_t error; - zip_error_init(&error); + // Create in-memory read-only zip file. + ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size); + // Open zip. 
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h -index bff8cdf5ef43e..dc9a992aee2be 100644 +index c2b28d18ef7d8..007919d581431 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h @@ -16,8 +16,8 @@ limitations under the License. @@ -36286,7 +39552,7 @@ #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc -index e21d426369e2e..2841c730adfd1 100644 +index 9d256b3322fb0..299ade3e95d54 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc @@ -19,9 +19,9 @@ limitations under the License. @@ -36301,7 +39567,7 @@ #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/status_macros.h" diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h -index a18e19bdb7973..9037f5853744b 100644 +index 510e6c04cdda1..4410f8481f97d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h @@ -17,8 +17,8 @@ limitations under the License. 
@@ -36315,7 +39581,7 @@ #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow_lite_support/cc/port/statusor.h" #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" -@@ -77,7 +77,8 @@ class ModelMetadataPopulator { +@@ -79,7 +79,8 @@ class ModelMetadataPopulator { // Zips and appends associated files to the provided model buffer. Called // internally by `Populate()`. tflite::support::StatusOr<std::string> AppendAssociatedFiles( @@ -36411,80 +39677,116 @@ if (table == nullptr) { // Should never happen, because VerifyModelMetadataBuffer has verified it. TFLITE_LOG(FATAL) << "The ModelMetadata object is null."; -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc -index 2e4d9107c8c31..f2b07e2054dfb 100644 ---- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc -+++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc -@@ -36,9 +36,13 @@ ZipMemFile::ZipMemFile(const char* buffer, size_t size) - zlib_filefunc_def_.opaque = this; +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc +index 3dac8c24af942..392b6b411fe03 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc +@@ -41,14 +41,17 @@ zlib_filefunc64_def& ZipReadOnlyMemFile::GetFileFunc64Def() { } --zlib_filefunc_def& ZipMemFile::GetFileFuncDef() { return zlib_filefunc_def_; } -+zlib_filefunc_def& ZipMemFile::GetFileFuncDef() { -+ return zlib_filefunc_def_; -+} + /* static */ +-voidpf ZipReadOnlyMemFile::OpenFile(voidpf opaque, const void* 
filename, ++voidpf ZipReadOnlyMemFile::OpenFile(voidpf opaque, ++ const void* filename, + int mode) { + // Result is never used, but needs to be non-null for `zipOpen2` not to fail. + return opaque; + } --absl::string_view ZipMemFile::GetFileContent() const { return data_; } -+absl::string_view ZipMemFile::GetFileContent() const { + /* static */ +-uLong ZipReadOnlyMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf, ++uLong ZipReadOnlyMemFile::ReadFile(voidpf opaque, ++ voidpf stream, ++ void* buf, + uLong size) { + auto* mem_file = static_cast<ZipReadOnlyMemFile*>(opaque); + if (mem_file->offset_ < 0 || mem_file->Size() < mem_file->offset_) { +@@ -65,8 +68,10 @@ uLong ZipReadOnlyMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf, + } + + /* static */ +-uLong ZipReadOnlyMemFile::WriteFile(voidpf opaque, voidpf stream, +- const void* buf, uLong size) { ++uLong ZipReadOnlyMemFile::WriteFile(voidpf opaque, ++ voidpf stream, ++ const void* buf, ++ uLong size) { + // File is not writable. + return 0; + } +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h +index 13927a7afa698..a1799ff509de5 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h +@@ -58,7 +58,9 @@ class ZipReadOnlyMemFile { + // The file function implementations used in the `zlib_filefunc64_def`. 
+ static voidpf OpenFile(voidpf opaque, const void* filename, int mode); + static uLong ReadFile(voidpf opaque, voidpf stream, void* buf, uLong size); +- static uLong WriteFile(voidpf opaque, voidpf stream, const void* buf, ++ static uLong WriteFile(voidpf opaque, ++ voidpf stream, ++ const void* buf, + uLong size); + static ZPOS64_T TellFile(voidpf opaque, voidpf stream); + static long SeekFile // NOLINT +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc +index 5999be028689a..38ad17ad8935c 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc +@@ -40,17 +40,22 @@ zlib_filefunc64_def& ZipWritableMemFile::GetFileFunc64Def() { + return zlib_filefunc64_def_; + } + +-absl::string_view ZipWritableMemFile::GetFileContent() const { return data_; } ++absl::string_view ZipWritableMemFile::GetFileContent() const { + return data_; +} /* static */ - voidpf ZipMemFile::OpenFile(voidpf opaque, const char* filename, int mode) { -@@ -47,7 +51,9 @@ voidpf ZipMemFile::OpenFile(voidpf opaque, const char* filename, int mode) { +-voidpf ZipWritableMemFile::OpenFile(voidpf opaque, const void* filename, ++voidpf ZipWritableMemFile::OpenFile(voidpf opaque, ++ const void* filename, + int mode) { + // Result is never used, but needs to be non-null for `zipOpen2` not to fail. 
+ return opaque; } /* static */ --size_t ZipMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf, -+size_t ZipMemFile::ReadFile(voidpf opaque, -+ voidpf stream, -+ void* buf, - size_t size) { - auto* mem_file = static_cast<ZipMemFile*>(opaque); +-uLong ZipWritableMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf, ++uLong ZipWritableMemFile::ReadFile(voidpf opaque, ++ voidpf stream, ++ void* buf, + uLong size) { + auto* mem_file = static_cast<ZipWritableMemFile*>(opaque); if (mem_file->offset_ < 0 || mem_file->Size() < mem_file->offset_) { -@@ -64,7 +70,9 @@ size_t ZipMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf, +@@ -67,8 +72,10 @@ uLong ZipWritableMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf, } /* static */ --size_t ZipMemFile::WriteFile(voidpf opaque, voidpf stream, const void* buf, -+size_t ZipMemFile::WriteFile(voidpf opaque, -+ voidpf stream, -+ const void* buf, - size_t size) { - auto* mem_file = static_cast<ZipMemFile*>(opaque); +-uLong ZipWritableMemFile::WriteFile(voidpf opaque, voidpf stream, +- const void* buf, uLong size) { ++uLong ZipWritableMemFile::WriteFile(voidpf opaque, ++ voidpf stream, ++ const void* buf, ++ uLong size) { + auto* mem_file = static_cast<ZipWritableMemFile*>(opaque); if (mem_file->offset_ + size > mem_file->Size()) { -@@ -82,7 +90,9 @@ ptrdiff_t ZipMemFile::TellFile(voidpf opaque, voidpf stream) { - } - - /* static */ --ptrdiff_t ZipMemFile::SeekFile(voidpf opaque, voidpf stream, size_t offset, -+ptrdiff_t ZipMemFile::SeekFile(voidpf opaque, -+ voidpf stream, -+ size_t offset, - int origin) { - auto* mem_file = static_cast<ZipMemFile*>(opaque); - switch (origin) { -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h -index ef7843d70cff6..d6748fcbe9ee1 100644 ---- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h -+++ 
b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h -@@ -56,10 +56,14 @@ class ZipMemFile { - // The file function implementations used in the `zlib_filefunc_def`. - static voidpf OpenFile(voidpf opaque, const char* filename, int mode); - static size_t ReadFile(voidpf opaque, voidpf stream, void* buf, size_t size); -- static size_t WriteFile(voidpf opaque, voidpf stream, const void* buf, -+ static size_t WriteFile(voidpf opaque, -+ voidpf stream, -+ const void* buf, - size_t size); - static ptrdiff_t TellFile(voidpf opaque, voidpf stream); -- static ptrdiff_t SeekFile(voidpf opaque, voidpf stream, size_t offset, -+ static ptrdiff_t SeekFile(voidpf opaque, -+ voidpf stream, -+ size_t offset, - int origin); - static int CloseFile(voidpf opaque, voidpf stream); - static int ErrorFile(voidpf opaque, voidpf stream); + mem_file->data_.resize(mem_file->offset_ + size); +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h +index 762dd58f0fb41..30e42fdb72a31 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h ++++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h +@@ -59,7 +59,9 @@ class ZipWritableMemFile { + // The file function implementations used in the `zlib_filefunc64_def`. 
+ static voidpf OpenFile(voidpf opaque, const void* filename, int mode); + static uLong ReadFile(voidpf opaque, voidpf stream, void* buf, uLong size); +- static uLong WriteFile(voidpf opaque, voidpf stream, const void* buf, ++ static uLong WriteFile(voidpf opaque, ++ voidpf stream, ++ const void* buf, + uLong size); + static ZPOS64_T TellFile(voidpf opaque, voidpf stream); + static long SeekFile // NOLINT diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc index 6185722504f69..8e00452bea983 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc @@ -42584,7 +45886,7 @@ + } } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h -index 110186bb63a1b..0c494915e7357 100644 +index 110186bb63a1b..18797d8135eb8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h @@ -19,7 +19,8 @@ @@ -42592,8 +45894,8 @@ /** Types of image sources. */ -typedef NSInteger GMLImageSourceType NS_TYPED_ENUM NS_SWIFT_NAME(MLImageSourceType); -+typedef NSInteger GMLImageSourceType NS_TYPED_ENUM -+ NS_SWIFT_NAME(MLImageSourceType); ++typedef NSInteger GMLImageSourceType ++ NS_TYPED_ENUM NS_SWIFT_NAME(MLImageSourceType); /** Image source is a `UIImage`. */ static const GMLImageSourceType GMLImageSourceTypeImage = 0; /** Image source is a `CVPixelBuffer`. 
*/ @@ -45620,6 +48922,18 @@ .def_property_readonly("audio_format", &AudioBuffer::GetAudioFormat) .def_property_readonly("buffer_size", &AudioBuffer::GetBufferSize) .def_property_readonly("float_buffer", [](AudioBuffer& self) { +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc +index 5d94db2a01b37..e2054cf645c08 100644 +--- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc ++++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc +@@ -20,7 +20,6 @@ limitations under the License. + #include "tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h" + #include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h" + #include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h" +-#include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h" + #include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" + + namespace tflite { diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc index 50e0b4f7ce4a8..8b1d67d9f8e05 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc @@ -45664,7 +48978,7 @@ .def("get_number_of_output_layers", &AudioEmbedder::GetNumberOfOutputLayers) diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc -index 8834d0e36816d..3b6bf2fc44dc5 100644 +index 
977b4e16175ac..124f5cb1ad15d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc @@ -43,13 +43,13 @@ PYBIND11_MODULE(image_utils, m) { @@ -45684,19 +48998,9 @@ data.pixel_data, sizeof(uint8), py::format_descriptor<uint8>::format(), 3, diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc -index 4497ba39623b2..f3f478d6f4f74 100644 +index 4ca20a363345e..b4f23baa6e0b1 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc -@@ -16,8 +16,8 @@ limitations under the License. 
- #include "pybind11/pybind11.h" - #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf - #include "tensorflow_lite_support/cc/task/processor/proto/bounding_box.pb.h" --#include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h" - #include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h" -+#include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h" - #include "tensorflow_lite_support/cc/task/vision/image_classifier.h" - #include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" - #include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" @@ -67,17 +67,17 @@ PYBIND11_MODULE(_pywrap_image_classifier, m) { return core::get_value(classifier); }) @@ -45734,7 +49038,7 @@ }); } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc -index 0ba816479cbae..19d1f31b2e78c 100644 +index 3ebf09fb4f284..e71048e9ebb0b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc @@ -47,23 +47,23 @@ PYBIND11_MODULE(_pywrap_image_segmenter, m) { @@ -45768,7 +49072,7 @@ }); } diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc -index d05fb32cf9d56..36fa2372e60af 100644 +index 39e39c9df00e1..3749efc811019 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc +++ 
b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc @@ -65,17 +65,16 @@ PYBIND11_MODULE(_pywrap_object_detector, m) { @@ -46148,48 +49452,6 @@ } // namespace -diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.proto b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.proto -index 5e0bfa4738ecd..af0c372b30a41 100644 ---- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.proto -+++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.proto -@@ -23,9 +23,7 @@ enum DistanceMeasure { - } - - message PartitionerProto { -- message Leaf { -- repeated float dimension = 1 [packed = true]; -- } -+ message Leaf { repeated float dimension = 1 [packed = true]; } - - repeated Leaf leaf = 1; - -@@ -35,13 +33,9 @@ message PartitionerProto { - } - - message AsymmetricHashingProto { -- message CodebookEntry { -- repeated float dimension = 1 [packed = true]; -- } -+ message CodebookEntry { repeated float dimension = 1 [packed = true]; } - -- message SubspaceCodebook { -- repeated CodebookEntry entry = 1; -- } -+ message SubspaceCodebook { repeated CodebookEntry entry = 1; } - - repeated SubspaceCodebook subspace = 1; - -@@ -56,9 +50,7 @@ message AsymmetricHashingProto { - optional LookupType lookup_type = 3 [default = FLOAT]; - } - message IndexerProto { -- oneof indexer { -- AsymmetricHashingProto asymmetric_hashing = 1; -- } -+ oneof indexer { AsymmetricHashingProto asymmetric_hashing = 1; } - } - message ScannOnDeviceConfig { - optional DistanceMeasure query_distance = 1 [default = UNSPECIFIED]; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h index 
8f53ddf0669c4..3e5a6b00736d0 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h @@ -46307,10 +49569,10 @@ // Parses and returns the `IndexConfig` stored in the index file. absl::StatusOr<IndexConfig> GetIndexConfig() const; diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc -index 0ebc6602c78f0..c77f7299e64a6 100644 +index fe5d1ef1175e4..0d802024c2b01 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc -@@ -20,13 +20,13 @@ limitations under the License. +@@ -21,13 +21,13 @@ limitations under the License. #include <vector> #include "absl/container/btree_map.h" // from @com_google_absl @@ -46331,20 +49593,20 @@ #include "tensorflow_lite_support/cc/port/status_macros.h" #include "tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h" #include "tensorflow_lite_support/scann_ondevice/cc/utils.h" -@@ -55,8 +55,10 @@ template <typename T> +@@ -56,8 +56,10 @@ template <typename T> absl::StatusOr<std::string> CreateIndexBufferImpl( absl::Span<const T> database, - absl::Span<const uint32_t> partition_assignment, + absl::optional<absl::Span<const uint32_t>> partition_assignment, - absl::Span<const std::string> metadata, const std::string& userinfo, - IndexConfig index_config, bool compression) { + absl::Span<const std::string> metadata, + const std::string& userinfo, + IndexConfig index_config, + bool compression) { - if (partition_assignment.size() != metadata.size()) { - return absl::InvalidArgumentError( - "Size of partition assignment and metadata mismatch"); -@@ -142,8 +144,8 @@ absl::StatusOr<std::string> CreateIndexBufferImpl( + size_t num_partitions = 1; + if (partition_assignment) { + if 
(partition_assignment->size() != metadata.size()) { +@@ -145,8 +147,8 @@ absl::StatusOr<std::string> CreateIndexBufferImpl( } // namespace @@ -46356,7 +49618,7 @@ artifacts.float_database.has_value()) { return absl::InvalidArgumentError( diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h -index 311feb7992f79..5701796943e28 100644 +index e8f8f06220578..53cac9b583da4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h @@ -16,12 +16,12 @@ limitations under the License. @@ -46377,7 +49639,7 @@ #include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h" namespace tflite { -@@ -59,8 +59,8 @@ struct IndexedArtifacts { +@@ -60,8 +60,8 @@ struct IndexedArtifacts { // Creates a byte buffer for the index file from the artifacts. Returns errors // when there are not exactly one database specified, or other issues with input // such as shape mismatch, invalid partition indices etc. @@ -46481,7 +49743,7 @@ namespace pybind11 { diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc -index 2fa253b58e19c..68830a9976e41 100644 +index 07da739f4a888..a1af840cc2f14 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc @@ -18,18 +18,18 @@ limitations under the License. 
@@ -46514,7 +49776,7 @@ #include "tensorflow_lite_support/cc/port/gmock.h" #include "tensorflow_lite_support/cc/port/gtest.h" #include "tensorflow_lite_support/cc/port/status_matchers.h" -@@ -125,22 +125,23 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabase) { +@@ -137,22 +137,23 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithPartitioner) { { tflite::scann_ondevice::core::ScannOnDeviceConfig config = @@ -46554,7 +49816,7 @@ std::vector<uint8_t> hashed_database; hashed_database.reserve(kNumEmbeddings * kDimensions); for (int i = 0; i < kNumEmbeddings; ++i) { -@@ -190,15 +191,17 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabase) { +@@ -202,16 +203,18 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithPartitioner) { auto hashed_table_iterator = absl::WrapUnique(hashed_table->NewIterator(leveldb::ReadOptions())); @@ -46565,8 +49827,9 @@ + LookupKey(hashed_table_iterator.get(), "INDEX_CONFIG")); IndexConfig index_config; EXPECT_TRUE(index_config.ParseFromString(serialized_config)); - EXPECT_THAT(index_config, - EqualsProto(CreateExpectedConfig(IndexConfig::UINT8))); + EXPECT_THAT( + index_config, + EqualsProto(CreateExpectedConfigWithPartitioner(IndexConfig::UINT8))); - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo, - LookupKey(hashed_table_iterator.get(), "USER_INFO")); @@ -46576,7 +49839,50 @@ EXPECT_EQ(userinfo, "hashed_userinfo"); // Partition assignment is based on i % kNumPartitions, so: -@@ -240,22 +243,23 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabase) { +@@ -253,9 +256,10 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithoutPartitioner) { + + { + tflite::scann_ondevice::core::ScannOnDeviceConfig config = +- ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(R"pb( +- query_distance: SQUARED_L2_DISTANCE +- )pb"); ++ ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>( ++ R"pb( ++ query_distance: SQUARED_L2_DISTANCE ++ )pb"); + std::vector<uint8_t> hashed_database; + 
hashed_database.reserve(kNumEmbeddings * kDimensions); + for (int i = 0; i < kNumEmbeddings; ++i) { +@@ -299,22 +303,23 @@ TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithoutPartitioner) { + auto float_table_iterator = + absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions())); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::string serialized_config, +- LookupKey(float_table_iterator.get(), "INDEX_CONFIG")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::string serialized_config, ++ LookupKey(float_table_iterator.get(), "INDEX_CONFIG")); + IndexConfig index_config; + EXPECT_TRUE(index_config.ParseFromString(serialized_config)); + EXPECT_THAT( + index_config, + EqualsProto(CreateExpectedConfigWithoutPartitioner(IndexConfig::UINT8))); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo, +- LookupKey(float_table_iterator.get(), "USER_INFO")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO")); + EXPECT_EQ(userinfo, "hashed_userinfo"); + + // Check that the unique embedding partition has the exact same contents as + // the database used at construction time. 
+ SUPPORT_ASSERT_OK_AND_ASSIGN(std::string raw_partition_hashed, +- LookupKey(float_table_iterator.get(), "E_0")); ++ LookupKey(float_table_iterator.get(), "E_0")); + std::vector<char> hashed_partition(raw_partition_hashed.begin(), + raw_partition_hashed.end()); + std::vector<char> expected; +@@ -342,22 +347,23 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithPartitioner) { { tflite::scann_ondevice::core::ScannOnDeviceConfig config = @@ -46616,7 +49922,7 @@ std::vector<float> float_database; float_database.reserve(kNumEmbeddings * kDimensions); for (int i = 0; i < kNumEmbeddings; ++i) { -@@ -305,15 +309,16 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabase) { +@@ -407,16 +413,17 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithPartitioner) { auto float_table_iterator = absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions())); @@ -46627,8 +49933,9 @@ + LookupKey(float_table_iterator.get(), "INDEX_CONFIG")); IndexConfig index_config; EXPECT_TRUE(index_config.ParseFromString(serialized_config)); - EXPECT_THAT(index_config, - EqualsProto(CreateExpectedConfig(IndexConfig::FLOAT))); + EXPECT_THAT( + index_config, + EqualsProto(CreateExpectedConfigWithPartitioner(IndexConfig::FLOAT))); - SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo, - LookupKey(float_table_iterator.get(), "USER_INFO")); @@ -46637,6 +49944,49 @@ EXPECT_EQ(userinfo, "float_userinfo"); // Partition assignment is based on i % kNumPartitions, so: +@@ -461,9 +468,10 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithoutPartitioner) { + + { + tflite::scann_ondevice::core::ScannOnDeviceConfig config = +- ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(R"pb( +- query_distance: SQUARED_L2_DISTANCE +- )pb"); ++ ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>( ++ R"pb( ++ query_distance: SQUARED_L2_DISTANCE ++ )pb"); + std::vector<float> float_database; + float_database.reserve(kNumEmbeddings * kDimensions); + for (int i = 0; i < 
kNumEmbeddings; ++i) { +@@ -506,22 +514,23 @@ TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithoutPartitioner) { + auto float_table_iterator = + absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions())); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::string serialized_config, +- LookupKey(float_table_iterator.get(), "INDEX_CONFIG")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::string serialized_config, ++ LookupKey(float_table_iterator.get(), "INDEX_CONFIG")); + IndexConfig index_config; + EXPECT_TRUE(index_config.ParseFromString(serialized_config)); + EXPECT_THAT( + index_config, + EqualsProto(CreateExpectedConfigWithoutPartitioner(IndexConfig::FLOAT))); + +- SUPPORT_ASSERT_OK_AND_ASSIGN(std::string userinfo, +- LookupKey(float_table_iterator.get(), "USER_INFO")); ++ SUPPORT_ASSERT_OK_AND_ASSIGN( ++ std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO")); + EXPECT_EQ(userinfo, "float_userinfo"); + + // Check that the unique embedding partition has the exact same contents as + // the database used at construction time. 
+ SUPPORT_ASSERT_OK_AND_ASSIGN(std::string raw_partition_float, +- LookupKey(float_table_iterator.get(), "E_0")); ++ LookupKey(float_table_iterator.get(), "E_0")); + const float* raw_partition_float_ptr = + reinterpret_cast<const float*>(raw_partition_float.data()); + std::vector<float> float_partition( diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_test.cc index 983dd8d2bc8e8..cc1225f679f66 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_test.cc @@ -46743,11 +50093,11 @@ ASSERT_TRUE(mem_writable_file->Append("aaa").ok()); EXPECT_EQ(buffer, "aaa"); diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc -index 2f5208b216530..1b6906bf49bc2 100644 +index ca364e06e7d1d..1ae7e0ce9ed09 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc -@@ -15,17 +15,17 @@ limitations under the License. - +@@ -16,17 +16,17 @@ limitations under the License. + #include <cstdint> #include <vector> -#include "absl/memory/memory.h" // from @com_google_absl @@ -46818,5 +50168,5 @@ #ifdef __cplusplus } -- -2.35.1.1178.g4f1659d476-goog +2.36.1.124.g0e6072fb45-goog
diff --git a/third_party/tflite_support/patches/0009-remove-unbuilt-files-with-presubmit-errors.patch b/third_party/tflite_support/patches/0009-remove-unbuilt-files-with-presubmit-errors.patch new file mode 100644 index 0000000..af9ce613 --- /dev/null +++ b/third_party/tflite_support/patches/0009-remove-unbuilt-files-with-presubmit-errors.patch
@@ -0,0 +1,31 @@ +From d1f39c9403c944a3e1df3952d452f2a314157ca0 Mon Sep 17 00:00:00 2001 +From: Robert Ogden <robertogden@chromium.org> +Date: Wed, 25 May 2022 11:05:14 -0700 +Subject: [PATCH] remove unbuilt files with presubmit errors + +--- + .../task/text/mobilebert_searcher.tflite | Bin 26199109 -> 0 bytes + .../ios/utils/Sources/TFLStringUtil.mm | 26 -- + .../scann_ondevice/cc/core/simd_utils.h | 303 ------------------ + .../tools/ci_build/common_win.bat | 29 -- + 4 files changed, 358 deletions(-) + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/mobilebert_searcher.tflite + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/simd_utils.h + delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat + +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/mobilebert_searcher.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/mobilebert_searcher.tflite +deleted file mode 100644 +index 43624e79655bccc9ea126f291c26fa633aa2cc79..0000000000000000000000000000000000000000 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm +deleted file mode 100644 +index 2a11bb6730474..0000000000000 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/simd_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/simd_utils.h +deleted file mode 100644 +index f239ec482382e..0000000000000 +diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat 
b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat +deleted file mode 100644 +index 116e0648d5f61..0000000000000 +-- +2.36.1.124.g0e6072fb45-goog +
diff --git a/third_party/tflite_support/src/.bazelrc b/third_party/tflite_support/src/.bazelrc index 07a32152..e49087ff 100644 --- a/third_party/tflite_support/src/.bazelrc +++ b/third_party/tflite_support/src/.bazelrc
@@ -113,11 +113,6 @@ build --spawn_strategy=local build -c opt -# Adding "--cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0" creates parity with TF -# compilation options. It also addresses memory use due to -# copy-on-write semantics of std::strings of the older ABI. -build --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0 - # Make Bazel print out all options from rc files. build --announce_rc @@ -138,11 +133,6 @@ # archives in -whole_archive -no_whole_archive. build --noincompatible_remove_legacy_whole_archive -# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0 -# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC: -# https://github.com/tensorflow/community/pull/179 -build --noincompatible_prohibit_aapt1 - # Build TF with C++ 17 features. build:c++17 --cxxopt=-std=c++1z build:c++17 --cxxopt=-stdlib=libc++
diff --git a/third_party/tflite_support/src/.bazelversion b/third_party/tflite_support/src/.bazelversion index 078bf8b7..3bff059 100644 --- a/third_party/tflite_support/src/.bazelversion +++ b/third_party/tflite_support/src/.bazelversion
@@ -1 +1 @@ -4.2.2 \ No newline at end of file +5.1.1 \ No newline at end of file
diff --git a/third_party/tflite_support/src/WORKSPACE b/third_party/tflite_support/src/WORKSPACE index 342154e2..90eee1f 100644 --- a/third_party/tflite_support/src/WORKSPACE +++ b/third_party/tflite_support/src/WORKSPACE
@@ -86,24 +86,24 @@ # https://github.com/bazelbuild/rules_apple/releases http_archive( name = "build_bazel_rules_apple", - sha256 = "0052d452af7742c8f3a4e0929763388a66403de363775db7e90adecb2ba4944b", + sha256 = "a5f00fd89eff67291f6cd3efdc8fad30f4727e6ebb90718f3f05bbf3c3dd5ed7", urls = [ - "https://github.com/bazelbuild/rules_apple/releases/download/0.31.3/rules_apple.0.31.3.tar.gz", + "https://github.com/bazelbuild/rules_apple/releases/download/0.33.0/rules_apple.0.33.0.tar.gz", ], ) # https://github.com/bazelbuild/rules_swift/releases http_archive( name = "build_bazel_rules_swift", - sha256 = "8407fa0fd04a7ce1d6bb95e90b216404466f809eda459c23cb57b5fa1ef9d639", + sha256 = "8a49da750560b185804a4bc95c82d3f9cc4c2caf788960b0e21945759155fdd9", urls = [ - "https://github.com/bazelbuild/rules_swift/releases/download/0.21.0/rules_swift.0.21.0.tar.gz", + "https://github.com/bazelbuild/rules_swift/releases/download/0.25.0/rules_swift.0.25.0.tar.gz", ], ) -# TF on 2022-01-28. -TENSORFLOW_COMMIT = "f2c2144d767a64236261fb4e4dd45947bd5f5815" -TENSORFLOW_SHA256 = "32ba2f6ea07572fd05cdae7520fe1bc38409f1a21bb4524076df27f1e23d09c1" +# TF on 2022-04-20. +TENSORFLOW_COMMIT = "314479be97046e0db0ff7662b1fbdb17af2ef4b4" +TENSORFLOW_SHA256 = "b3c439aa7da6780956780e0cb312011416fcd476201f8033f90b3c4fc1cff7a0" http_archive( name = "org_tensorflow", sha256 = TENSORFLOW_SHA256, @@ -138,9 +138,8 @@ build_file = "@pybind11_bazel//:pybind11.BUILD", ) -# TODO(b/211393391): Updates the commit number once the new change is ready. 
-PP_COMMIT = "30f02dd9ccd2fc7046c36ed913ed510fd1aa7301" -PP_SHA256 = "178bcd587956b0f449fff2f46e663dc10baa6d4951a0a7f48cddfeef57d593a8" +PP_COMMIT = "3594106f2df3d725e65015ffb4c7886d6eeee683" +PP_SHA256 = "baa1f53568283630a5055c85f0898b8810f7a6431bd01bbaedd32b4c1defbcb1" http_archive( name = "pybind11_protobuf", sha256 = PP_SHA256, @@ -253,19 +252,12 @@ "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 ], -) - -http_archive( - name = "org_libzip", - build_file = "//third_party:libzip.BUILD", - sha256 = "a5d22f0c87a2625450eaa5e10db18b8ee4ef17042102d04c62e311993a2ba363", - strip_prefix = "libzip-rel-1-5-1", - urls = [ - # Bazel does not like the official download link at libzip.org, - # so use the GitHub release tag. - "https://mirror.bazel.build/github.com/nih-at/libzip/archive/rel-1-5-1.zip", - "https://github.com/nih-at/libzip/archive/rel-1-5-1.zip", + patches = [ + "@//third_party:zlib.patch" ], + patch_args = [ + "-p1" + ] ) http_archive(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/BUILD index d6da77e..e48a288f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/BUILD
@@ -89,7 +89,7 @@ ], deps = [ ":edgetpu_coral_plugin", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", "@com_google_googletest//:gtest_main", "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/c:common",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc index cc183a6..6ac4e5c7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
@@ -21,7 +21,7 @@ #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" namespace tflite { namespace delegates {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/BUILD new file mode 100644 index 0000000..8aca2a8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/BUILD
@@ -0,0 +1,35 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library_with_tflite( + name = "audio_classifier", + srcs = [ + "audio_classifier.cc", + ], + hdrs = [ + "audio_classifier.h", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/audio:audio_classifier", + ], + deps = [ + "//tensorflow_lite_support/c:common", + "//tensorflow_lite_support/c:common_utils", + "//tensorflow_lite_support/c/task/audio/core:audio_buffer", + "//tensorflow_lite_support/c/task/audio/utils:audio_buffer_cpp_c_utils", + "//tensorflow_lite_support/c/task/core:base_options", + "//tensorflow_lite_support/c/task/core/utils:base_options_utils", + "//tensorflow_lite_support/c/task/processor:classification_options", + "//tensorflow_lite_support/c/task/processor:classification_result", + "//tensorflow_lite_support/c/task/processor/utils:classification_options_utils", + "//tensorflow_lite_support/cc/task/audio/proto:audio_classifier_options_cc_proto", + "//tensorflow_lite_support/cc/task/audio/proto:class_proto_inc", + "//tensorflow_lite_support/cc/task/audio/proto:classifications_proto_inc", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc new file mode 100644 index 0000000..3f1781a0 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.cc
@@ -0,0 +1,252 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/task/audio/audio_classifier.h" + +#include <memory> + +#include "tensorflow_lite_support/c/common_utils.h" +#include "tensorflow_lite_support/c/task/audio/utils/audio_buffer_cpp_c_utils.h" +#include "tensorflow_lite_support/c/task/core/utils/base_options_utils.h" +#include "tensorflow_lite_support/c/task/processor/utils/classification_options_utils.h" +#include "tensorflow_lite_support/cc/task/audio/audio_classifier.h" +#include "tensorflow_lite_support/cc/task/audio/proto/audio_classifier_options.pb.h" +#include "tensorflow_lite_support/cc/task/audio/proto/class_proto_inc.h" +#include "tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h" + +namespace { +using ::tflite::support::StatusOr; +using ClassificationResultCpp = ::tflite::task::audio::ClassificationResult; +using ClassificationsCpp = ::tflite::task::audio::Classifications; +using ClassCpp = ::tflite::task::audio::Class; +using AudioClassifierCpp = ::tflite::task::audio::AudioClassifier; +using AudioClassifierOptionsCpp = ::tflite::task::audio::AudioClassifierOptions; +using AudioBufferCpp = ::tflite::task::audio::AudioBuffer; +using ::tflite::support::TfLiteSupportStatus; + +StatusOr<AudioClassifierOptionsCpp> CreateAudioClassifierCppOptionsFromCOptions( + const 
TfLiteAudioClassifierOptions* c_options) { + if (c_options == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Expected non null options."), + TfLiteSupportStatus::kInvalidArgumentError); + } + + AudioClassifierOptionsCpp cpp_options = {}; + + // More file sources can be added in else ifs + if (c_options->base_options.model_file.file_path) + cpp_options.mutable_base_options()->mutable_model_file()->set_file_name( + c_options->base_options.model_file.file_path); + + // c_options->base_options.compute_settings.num_threads is expected to be + // set to value > 0 or -1. Otherwise invoking + // ImageClassifierCpp::CreateFromOptions() results in a not ok status. + cpp_options.mutable_base_options() + ->mutable_compute_settings() + ->mutable_tflite_settings() + ->mutable_cpu_settings() + ->set_num_threads( + c_options->base_options.compute_settings.cpu_settings.num_threads); + + for (int i = 0; i < c_options->classification_options.label_denylist.length; + i++) + cpp_options.add_class_name_denylist( + c_options->classification_options.label_denylist.list[i]); + + for (int i = 0; i < c_options->classification_options.label_allowlist.length; + i++) + cpp_options.add_class_name_allowlist( + c_options->classification_options.label_allowlist.list[i]); + + // Check needed since setting a nullptr for this field results in a segfault + // on invocation of ImageClassifierCpp::CreateFromOptions(). + if (c_options->classification_options.display_names_local) { + cpp_options.set_display_names_locale( + c_options->classification_options.display_names_local); + } + + // c_options->classification_options.max_results is expected to be set to -1 + // or any value > 0. Otherwise invoking + // ImageClassifierCpp::CreateFromOptions() results in a not ok status. 
+ cpp_options.set_max_results(c_options->classification_options.max_results); + + cpp_options.set_score_threshold( + c_options->classification_options.score_threshold); + + return cpp_options; +} +} // namespace + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct TfLiteAudioClassifier { + std::unique_ptr<AudioClassifierCpp> impl; +}; + +TfLiteAudioClassifierOptions TfLiteAudioClassifierOptionsCreate(void) { + // Use brace-enclosed initializer list will break the Kokoro test. + TfLiteAudioClassifierOptions options; + options.classification_options = + tflite::task::processor::CreateDefaultClassificationOptions(); + options.base_options = tflite::task::core::CreateDefaultBaseOptions(); + return options; +} + +TfLiteAudioClassifier* TfLiteAudioClassifierFromOptions( + const TfLiteAudioClassifierOptions* options, + TfLiteSupportError** error) { + StatusOr<AudioClassifierOptionsCpp> cpp_option_status = + CreateAudioClassifierCppOptionsFromCOptions(options); + + if (!cpp_option_status.ok()) { + ::tflite::support::CreateTfLiteSupportErrorWithStatus( + cpp_option_status.status(), error); + return nullptr; + } + + StatusOr<std::unique_ptr<AudioClassifierCpp>> classifier_status = + AudioClassifierCpp::CreateFromOptions(cpp_option_status.value()); + + if (classifier_status.ok()) { + return new TfLiteAudioClassifier{.impl = + std::move(classifier_status.value())}; + } else { + ::tflite::support::CreateTfLiteSupportErrorWithStatus( + classifier_status.status(), error); + return nullptr; + } +} + +TfLiteClassificationResult* GetClassificationResultCStruct( + const ClassificationResultCpp& classification_result_cpp) { + auto c_classifications = + new TfLiteClassifications[classification_result_cpp + .classifications_size()]; + + for (int head = 0; head < classification_result_cpp.classifications_size(); + ++head) { + const ClassificationsCpp& classifications = + classification_result_cpp.classifications(head); + c_classifications[head].head_index = 
classifications.head_index(); + + if (classifications.has_head_name()) { + c_classifications[head].head_name = + strdup(classifications.head_name().c_str()); + } + + auto c_categories = new TfLiteCategory[classifications.classes_size()]; + c_classifications->size = classifications.classes_size(); + + for (int rank = 0; rank < classifications.classes_size(); ++rank) { + const ClassCpp& classification = classifications.classes(rank); + c_categories[rank].index = classification.index(); + c_categories[rank].score = classification.score(); + + if (classification.has_class_name()) + c_categories[rank].label = strdup(classification.class_name().c_str()); + else + c_categories[rank].label = nullptr; + + if (classification.has_display_name()) + c_categories[rank].display_name = + strdup(classification.display_name().c_str()); + else + c_categories[rank].display_name = nullptr; + } + c_classifications[head].categories = c_categories; + } + + auto c_classification_result = new TfLiteClassificationResult; + c_classification_result->classifications = c_classifications; + c_classification_result->size = + classification_result_cpp.classifications_size(); + + return c_classification_result; +} + +TfLiteClassificationResult* TfLiteAudioClassifierClassify( + const TfLiteAudioClassifier* classifier, + const TfLiteAudioBuffer* audio_buffer, + TfLiteSupportError** error) { + if (classifier == nullptr) { + tflite::support::CreateTfLiteSupportError( + kInvalidArgumentError, "Expected non null audio classifier.", error); + return nullptr; + } + + StatusOr<std::unique_ptr<AudioBufferCpp>> cpp_audio_buffer_status = + ::tflite::task::audio::CreateCppAudioBuffer(audio_buffer); + if (!cpp_audio_buffer_status.ok()) { + tflite::support::CreateTfLiteSupportErrorWithStatus( + cpp_audio_buffer_status.status(), error); + return nullptr; + } + + // fnc_sample(cpp_audio_buffer_status); + StatusOr<ClassificationResultCpp> cpp_classification_result_status = + 
classifier->impl->Classify(*(cpp_audio_buffer_status.value())); + + if (!cpp_classification_result_status.ok()) { + tflite::support::CreateTfLiteSupportErrorWithStatus( + cpp_classification_result_status.status(), error); + return nullptr; + } + + return GetClassificationResultCStruct( + cpp_classification_result_status.value()); +} + +int TfLiteAudioClassifierGetRequiredInputBufferSize( + TfLiteAudioClassifier* classifier, + TfLiteSupportError** error) { + if (classifier == nullptr) { + tflite::support::CreateTfLiteSupportError( + kInvalidArgumentError, "Expected non null audio classifier.", error); + return -1; + } + + return classifier->impl->GetRequiredInputBufferSize(); +} + +void TfLiteAudioClassifierDelete(TfLiteAudioClassifier* classifier) { + delete classifier; +} + +TfLiteAudioFormat* TfLiteAudioClassifierGetRequiredAudioFormat( + TfLiteAudioClassifier* classifier, + TfLiteSupportError** error) { + if (classifier == nullptr) { + tflite::support::CreateTfLiteSupportError( + kInvalidArgumentError, "Expected non null audio classifier.", error); + return nullptr; + } + + StatusOr<TfLiteAudioFormat*> c_audio_format = + CreateCAudioFormat(classifier->impl->GetRequiredAudioFormat()); + + if (!c_audio_format.ok()) { + return nullptr; + } + + return c_audio_format.value(); +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h new file mode 100644 index 0000000..6af9b27 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/audio_classifier.h
@@ -0,0 +1,209 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_AUDIO_AUDIO_CLASSIFIER_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_AUDIO_AUDIO_CLASSIFIER_H_ + +#include <stdint.h> + +#include "tensorflow_lite_support/c/common.h" +#include "tensorflow_lite_support/c/task/audio/core/audio_buffer.h" +#include "tensorflow_lite_support/c/task/core/base_options.h" +#include "tensorflow_lite_support/c/task/processor/classification_options.h" +#include "tensorflow_lite_support/c/task/processor/classification_result.h" + +// -------------------------------------------------------------------------- +/// C API for AudioClassifiier. +/// +/// The API leans towards simplicity and uniformity instead of convenience, as +/// most usage will be by language-specific wrappers. It provides largely the +/// same set of functionality as that of the C++ TensorFlow Lite +/// `AudioClassifier` API, but is useful for shared libraries where having +/// a stable ABI boundary is important. +/// +/// Usage: +/// <pre><code> +/// // Create the model +/// Using the options initialized with default values returned by +/// TfLiteAudioClassifierOptionsCreate() makes sure that there will be no +/// undefined behaviour due to garbage values in unitialized members. 
+/// TfLiteAudioClassifierOptions options = TfLiteAudioClassifierOptionsCreate(); +/// +/// Set the model file path in options +/// options.base_options.model_file.file_path = "/path/to/model.tflite"; +/// +/// If need be, set values for any options to customize behaviour. +/// options.base_options.compute_settings.cpu_settings.num_threads = 3 +/// +/// Create TfLiteAudioClassifier using the options: +/// If error information is not nedded in case of failure: +/// TfLiteAudioClassifier* audio_classifier = +/// TfLiteAudioClassifierFromOptions(&options, NULL); +/// +/// If error information is nedded in case of failure: +/// TfLiteSupportError* create_error = NULL; +/// TfLiteAudioClassifier* audio_classifier = +/// TfLiteAudioClassifierFromOptions(&options, &create_error); +/// +/// if (!audio_classifier) { +/// Handle failure. +/// Do something with `create_error`, if requested as illustrated above. +/// } +/// +/// Dispose of the create_error object. +/// TfLiteSupportErrorDelete(create_error); +/// +/// Classify an audio +/// TfLiteFrameBuffer frame_buffer = { Initialize with audio data } +/// +/// If error information is not nedded in case of failure: +/// TfLiteClassificationResult* classification_result = +/// TfLiteAudioClassifierClassify(audio_classifier, &frame_buffer, NULL); +/// +/// If error information is nedded in case of failure: +/// TfLiteSupportError* classify_error = NULL; +/// TfLiteClassificationResult* classification_result = +/// TfLiteAudioClassifierClassify(audio_classifier, &frame_buffer, +/// &classify_error); +/// +/// if (!classification_result) { +/// Handle failure. +/// Do something with `classify_error`, if requested as illustrated above. +/// } +/// +/// Dispose of the classify_error object. +/// TfLiteSupportErrorDelete(classify_error); +/// +/// Dispose of the API object. 
+/// TfLiteAudioClassifierDelete(audio_classifier); + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct TfLiteAudioClassifier TfLiteAudioClassifier; + +typedef struct TfLiteAudioClassifierOptions { + TfLiteClassificationOptions classification_options; + TfLiteBaseOptions base_options; +} TfLiteAudioClassifierOptions; + +// Creates and returns TfLiteAudioClassifierOptions initialized with default +// values. Default values are as follows: +// 1. .classification_options.max_results = -1, which returns all classification +// categories by default. +// 2. .base_options.compute_settings.tflite_settings.cpu_settings.num_threads = +// -1, which makes the TFLite runtime choose the value. +// 3. .classification_options.score_threshold = 0 +// 4. All pointers like .base_options.model_file.file_path, +// .base_options.classification_options.display_names_local, +// .classification_options.label_allowlist.list, +// options.classification_options.label_denylist.list are NULL. +// 5. All other integer values are initialized to 0. +TfLiteAudioClassifierOptions TfLiteAudioClassifierOptionsCreate(void); + +// Creates TfLiteAudioClassifier from options. +// .base_options.model_file.file_path in TfLiteAudioClassifierOptions should be +// set to the path of the tflite model you wish to create the +// TfLiteAudioClassifier with. +// Create TfLiteAudioClassifierOptions using +// TfLiteAudioClassifierOptionsCreate(). If need be, you can change the default +// values of options for customizing classification, If options are not created +// in the aforementioned way, you have to make sure that all members are +// initialized to respective default values and all pointer members are zero +// initialized to avoid any undefined behaviour. +// +// Returns the created audio classifier in case of success. +// Returns nullptr on failure which happens commonly due to one of the following +// issues: +// 1. file doesn't exist or is not a well formatted. +// 2. 
options is nullptr. +// 3. Both options.classification_options.label_denylist and +// options.classification_options.label_allowlist are non empty. These +// fields are mutually exclusive. +// +// The caller can check if an error was encountered by testing if the returned +// value of the function is null. If the caller doesn't want the reason for +// failure, they can simply pass a NULL for the address of the error pointer as +// shown below: +// +// TfLiteAudioClassifier* classifier = TfLiteAudioClassifierFromOptions(options, +// NULL); +// +// If the caller wants to be informed of the reason for failure, they must pass +// the adress of a pointer of type TfLiteSupportError to the `error` param as +// shown below: +// +// TfLiteSupport *error = NULL: +// TfLiteAudioClassifier* classifier = TfLiteAudioClassifierFromOptions(options, +// &error); +// +// In case of unsuccessful execution, Once the function returns, the error +// pointer will point to a struct containing the error information. If error +// info is passed back to the caller, it is the responsibility of the caller to +// free the error struct by calling the following method defined in common.h: +// +// TfLiteSupportErrorDelete(error) +// +TfLiteAudioClassifier* TfLiteAudioClassifierFromOptions( + const TfLiteAudioClassifierOptions* options, + TfLiteSupportError** error); + +// Invokes the encapsulated TFLite model and classifies the frame_buffer. +// Returns a pointer to the created classification result in case of success or +// NULL in case of failure. The caller must test the return value to identify +// success or failure. 
If the caller doesn't want the reason for failure, they +// can simply pass a NULL for the address of the error pointer as shown below: +// +// TfLiteClassificationResult* classification_result = +// TfLiteAudioClassifierClassify(&options, NULL); +// +// If the caller wants to be informed of the reason for failure, they must pass +// the adress of a pointer of type TfLiteSupportError to the `error` param as +// shown below: +// +// TfLiteSupport *error = NULL: +// TfLiteAudioClassifier* classifier = TfLiteAudioClassifierFromOptions(options, +// &error); +// +// In case of unsuccessful execution, Once the function returns, the error +// pointer will point to a struct containing the error information. If error +// info is passed back to the caller, it is the responsibility of the caller to +// free the error struct by calling the following method defined in common.h: +// +// TfLiteSupportErrorDelete(error) +// +TfLiteClassificationResult* TfLiteAudioClassifierClassify( + const TfLiteAudioClassifier* classifier, + const TfLiteAudioBuffer* audio_buffer, + TfLiteSupportError** error); + +// Returns the input buffer size required by the audio classifier. +int TfLiteAudioClassifierGetRequiredInputBufferSize( + TfLiteAudioClassifier* classifier, + TfLiteSupportError** error); + +// Returns the audio format required by the audio classifier. +TfLiteAudioFormat* TfLiteAudioClassifierGetRequiredAudioFormat( + TfLiteAudioClassifier* classifier, + TfLiteSupportError** error); + +// Disposes off the audio classifier. +void TfLiteAudioClassifierDelete(TfLiteAudioClassifier* classifier); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_AUDIO_AUDIO_CLASSIFIER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/BUILD new file mode 100644 index 0000000..c67cc54b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/BUILD
@@ -0,0 +1,16 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "audio_buffer", + srcs = [ + "audio_buffer.cc", + ], + hdrs = [ + "audio_buffer.h", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.cc new file mode 100644 index 0000000..a8437b1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.cc
@@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/c/task/audio/core/audio_buffer.h" + +#include <stdlib.h> + +void TfLiteAudioBufferDelete(TfLiteAudioBuffer* audio_buffer) { + TfLiteAudioBufferDeleteData(*audio_buffer); + free(audio_buffer); +} + +void TfLiteAudioBufferDeleteData(const TfLiteAudioBuffer audio_buffer) { + free(audio_buffer.data); +} + +void TfLiteAudioFormatDelete(TfLiteAudioFormat* audio_format) { + free(audio_format); +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h new file mode 100644 index 0000000..471f02f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/core/audio_buffer.h
@@ -0,0 +1,58 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_AUDIO_AUDIO_BUFFER_H_ +#define TENSORFLOW_LITE_SUPPORT_C_TASK_AUDIO_AUDIO_BUFFER_H_ + +// Defines C structs for holding the audio buffer. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Holds audio format metadata. +typedef struct TfLiteAudioFormat { + // The number of channels in the audio buffer. + int channels; + // The sample rate of the audio buffer. + int sample_rate; +} TfLiteAudioFormat; + +// A `TfLiteAudioBuffer` provides a view into the provided backing buffer and +// the audio format metadata.. TfLiteAudioBuffer doesn't take ownership of the +// provided backing buffer. The caller is responsible to manage the backing +// buffer lifecycle for the lifetime of the TfLiteAudioBuffer. +typedef struct TfLiteAudioBuffer { + TfLiteAudioFormat format; + + // Backing buffer that holds the audio samples which are to be processed. For + // muti channel data array is expected to be interleaved . + float* data; + + // Size of the audio buffer. This size can be used to loop through the + // audio_buffer. 
+ int size; +} TfLiteAudioBuffer; + +void TfLiteAudioBufferDelete(TfLiteAudioBuffer* buffer); + +void TfLiteAudioBufferDeleteData(const TfLiteAudioBuffer audio_buffer); + +void TfLiteAudioFormatDelete(TfLiteAudioFormat* format); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_AUDIO_AUDIO_BUFFER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/utils/BUILD new file mode 100644 index 0000000..ef9c6ff --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/utils/BUILD
@@ -0,0 +1,22 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "audio_buffer_cpp_c_utils", + srcs = [ + "audio_buffer_cpp_c_utils.cc", + ], + hdrs = [ + "audio_buffer_cpp_c_utils.h", + ], + deps = [ + "//tensorflow_lite_support/c/task/audio/core:audio_buffer", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/task/audio/core:audio_buffer", + "@com_google_absl//absl/strings:str_format", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/utils/audio_buffer_cpp_c_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/utils/audio_buffer_cpp_c_utils.cc
new file mode 100644
index 0000000..560d2736
--- /dev/null
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/utils/audio_buffer_cpp_c_utils.cc
@@ -0,0 +1,81 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/c/task/audio/utils/audio_buffer_cpp_c_utils.h"
+
+#include <memory>
+
+#include "absl/strings/str_format.h"  // from @com_google_absl
+#include "tensorflow_lite_support/cc/common.h"
+
+namespace tflite {
+namespace task {
+namespace audio {
+
+namespace {
+using AudioBufferCpp = ::tflite::task::audio::AudioBuffer;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+}  // namespace
+
+// Creates a C++ `AudioBuffer` from the samples and format of the C
+// `audio_buffer`. Returns an InvalidArgument status if `audio_buffer` is null,
+// if its `data` is null, or if its `size` is negative.
+StatusOr<std::unique_ptr<AudioBufferCpp>> CreateCppAudioBuffer(
+    const TfLiteAudioBuffer* audio_buffer) {
+  if (audio_buffer == nullptr) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat("Expected non null audio buffer."),
+        TfLiteSupportStatus::kInvalidArgumentError);
+  }
+  // TfLiteAudioBuffer is a plain C view over caller-owned memory, so validate
+  // the view here before handing it to the C++ API.
+  if (audio_buffer->data == nullptr) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "Expected non null audio buffer data.",
+        TfLiteSupportStatus::kInvalidArgumentError);
+  }
+  if (audio_buffer->size < 0) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat("Expected non negative audio buffer size, got %d.",
+                        audio_buffer->size),
+        TfLiteSupportStatus::kInvalidArgumentError);
+  }
+
+  return AudioBufferCpp::Create(
+      audio_buffer->data, audio_buffer->size,
+      {audio_buffer->format.channels, audio_buffer->format.sample_rate});
+}
+
+// Converts a C++ audio format into a C `TfLiteAudioFormat` allocated with
+// `new`; the caller owns the returned pointer. An error status held by
+// `cpp_audio_format` is propagated unchanged.
+StatusOr<TfLiteAudioFormat*> CreateCAudioFormat(
+    StatusOr<AudioBufferCpp::AudioFormat> cpp_audio_format) {
+  if (!cpp_audio_format.ok()) {
+    return cpp_audio_format.status();
+  }
+
+  return new TfLiteAudioFormat{cpp_audio_format->channels,
+                               cpp_audio_format->sample_rate};
+}
+
+}  // namespace audio
+}  // namespace task
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/utils/audio_buffer_cpp_c_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/utils/audio_buffer_cpp_c_utils.h
new file mode 100644
index 0000000..8c75d559
--- /dev/null
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/audio/utils/audio_buffer_cpp_c_utils.h
@@ -0,0 +1,46 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_AUDIO_UTILS_AUDIO_BUFFER_CPP_C_UTILS_H_
+#define TENSORFLOW_LITE_SUPPORT_C_TASK_AUDIO_UTILS_AUDIO_BUFFER_CPP_C_UTILS_H_
+
+#include <memory>
+
+#include "tensorflow_lite_support/c/task/audio/core/audio_buffer.h"
+#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h"
+
+// Utils for conversions between the C and C++ AudioBuffer types.
+// -----------------------------------------------------------------
+// Meant to be used with the audio C APIs.
+
+namespace tflite {
+namespace task {
+namespace audio {
+
+// Creates the C++ AudioBuffer from the C AudioBuffer. Returns an
+// InvalidArgument status if `audio_buffer` is null.
+tflite::support::StatusOr<std::unique_ptr<tflite::task::audio::AudioBuffer>>
+CreateCppAudioBuffer(const TfLiteAudioBuffer* audio_buffer);
+
+// Creates a C `TfLiteAudioFormat` allocated with `new` from a C++ audio
+// format; the caller owns the returned pointer. Any error status held by
+// `cpp_audio_format` is returned unchanged.
+tflite::support::StatusOr<TfLiteAudioFormat*> CreateCAudioFormat(
+    tflite::support::StatusOr<tflite::task::audio::AudioBuffer::AudioFormat>
+        cpp_audio_format);
+
+}  // namespace audio
+}  // namespace task
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_SUPPORT_C_TASK_AUDIO_UTILS_AUDIO_BUFFER_CPP_C_UTILS_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
index 0c18c58..b7d7fab 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
@@ -27,6 +27,7 @@
   for (int head = 0; head < classification_result->size; ++head) {
     TfLiteClassifications classifications =
         classification_result->classifications[head];
+    free(classifications.head_name);
     for (int rank = 0; rank < classifications.size; ++rank) {
       TfLiteCategoryDelete(&(classifications.categories[rank]));
     }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.h
index 1b73365a..69ee9149 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.h
@@ -29,6 +29,12 @@
   // useful for multi-head models.
   int head_index;
 
+  // The name of the classifier head, which is the corresponding tensor metadata
+  // name. See
+  // https://github.com/tensorflow/tflite-support/blob/710e323265bfb71fdbdd72b3516e00cff15c0326/tensorflow_lite_support/metadata/metadata_schema.fbs#L545
+  // This will always be NULL for vision APIs.
+  char* head_name;
+
   // Number of predicted classes which can be used to traverse the array of
   // predicted classes.
   int size;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.cc
index c26ce05..0211c8d5 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.cc
@@ -21,7 +21,7 @@
 extern "C" {
 #endif  // __cplusplus
 
-void NLClassifierCategoriesDelete(Categories* categories) {
+void TfLiteNLClassifierCategoriesDelete(Categories* categories) {
   for (int i = 0; i < categories->size; i++) {
     // `strdup` obtains memory using `malloc` and the memory needs to be
     // released using `free`.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.h b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.h
index ed4d1c8..9f4851a3 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.h
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier_common.h
@@ -34,7 +34,7 @@
   Category* categories;
 } Categories;
 
-void NLClassifierCategoriesDelete(Categories* categories);
+void TfLiteNLClassifierCategoriesDelete(Categories* categories);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
index 8d9aa85..183468a 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
@@ -144,7 +144,8 @@
        ++head) {
     const ClassificationsCpp& classifications =
         classification_result_cpp.classifications(head);
-    c_classifications[head].head_index = head;
+    c_classifications[head].head_index = classifications.head_index();
+    c_classifications[head].head_name = nullptr;
 
     auto c_categories = new TfLiteCategory[classifications.classes_size()];
-    c_classifications->size = classifications.classes_size();
+    c_classifications[head].size = classifications.classes_size();
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/BUILD
new file mode 100644
index 0000000..6b8e79dd
--- /dev/null
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/BUILD
@@ -0,0 +1,34 @@
+load(
+    "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl",
+    "cc_test_with_tflite",
+)
+
+package(
+    default_visibility = [
+        "//visibility:private",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# bazel test tensorflow_lite_support/c/test/task/audio:audio_classifier_test
+cc_test_with_tflite(
+    name = "audio_classifier_test",
+    srcs = ["audio_classifier_test.cc"],
+    data = [
+        "//tensorflow_lite_support/cc/test/testdata/task/audio:test_audio_clips",
+        "//tensorflow_lite_support/cc/test/testdata/task/audio:test_models",
+    ],
+    tflite_deps = [
+        "//tensorflow_lite_support/c/task/audio:audio_classifier",
+        "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
+    ],
+    deps = [
+        "//tensorflow_lite_support/c:common",
+        "//tensorflow_lite_support/c/task/audio/core:audio_buffer",
+        "//tensorflow_lite_support/c/task/processor:classification_result",
+        "//tensorflow_lite_support/cc/port:gtest_main",
+        "//tensorflow_lite_support/cc/port:statusor",
+        "//tensorflow_lite_support/cc/task/audio/utils:wav_io",
+        "//tensorflow_lite_support/cc/test:test_utils",
+    ],
+)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc
new file mode 100644
index 0000000..126784cf
--- /dev/null
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/audio/audio_classifier_test.cc
@@ -0,0 +1,231 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/c/task/audio/audio_classifier.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
+#include "tensorflow_lite_support/c/common.h"
+#include "tensorflow_lite_support/c/task/audio/core/audio_buffer.h"
+#include "tensorflow_lite_support/c/task/processor/classification_result.h"
+#include "tensorflow_lite_support/cc/port/gmock.h"
+#include "tensorflow_lite_support/cc/port/gtest.h"
+#include "tensorflow_lite_support/cc/port/status_matchers.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/audio/utils/wav_io.h"
+#include "tensorflow_lite_support/cc/test/test_utils.h"
+
+namespace tflite {
+namespace task {
+namespace audio {
+namespace {
+
+using ::testing::HasSubstr;
+using ::tflite::support::StatusOr;
+using ::tflite::task::JoinPath;
+
+constexpr char kTestDataDirectory[] =
+    "/tensorflow_lite_support/cc/test/testdata/task/"
+    "audio/";
+// Quantized model.
+constexpr char kYamNetAudioClassifierWithMetadata[] =
+    "yamnet_audio_classifier_with_metadata.tflite";
+
+// Decodes `wav_file` from the test data directory into a `TfLiteAudioBuffer`
+// whose `size` is capped at `buffer_size` (when non-negative). The samples are
+// copied into a malloc'ed array; the caller releases it with
+// `TfLiteAudioBufferDeleteData()`.
+StatusOr<TfLiteAudioBuffer> LoadAudioBufferFromFileNamed(
+    const std::string& wav_file,
+    int buffer_size) {
+  std::string contents =
+      ReadFile(JoinPath("./" /*test src dir*/, kTestDataDirectory, wav_file));
+
+  uint32_t decoded_sample_count;
+  uint16_t decoded_channel_count;
+  uint32_t decoded_sample_rate;
+  std::vector<float> wav_data;
+
+  absl::Status read_audio_file_status = DecodeLin16WaveAsFloatVector(
+      contents, &wav_data, &decoded_sample_count, &decoded_channel_count,
+      &decoded_sample_rate);
+
+  // Check the status first: the decoded_* outputs are uninitialized locals and
+  // must not be read if decoding failed.
+  if (!read_audio_file_status.ok()) {
+    return read_audio_file_status;
+  }
+
+  // Only clamp for a non-negative `buffer_size`, comparing as unsigned values.
+  if (buffer_size >= 0 &&
+      decoded_sample_count > static_cast<uint32_t>(buffer_size)) {
+    decoded_sample_count = static_cast<uint32_t>(buffer_size);
+  }
+
+  float* c_wav_data =
+      static_cast<float*>(malloc(sizeof(float) * wav_data.size()));
+  if (c_wav_data == nullptr) {
+    exit(-1);
+  }
+
+  memcpy(c_wav_data, wav_data.data(), sizeof(float) * wav_data.size());
+
+  TfLiteAudioBuffer audio_buffer = {
+      .format = {.channels = decoded_channel_count,
+                 .sample_rate = static_cast<int>(decoded_sample_rate)},
+      .data = c_wav_data,
+      .size = static_cast<int>(decoded_sample_count)};
+
+  return audio_buffer;
+}
+
+// Checks the top-level shape of a classification result.
+void Verify(TfLiteClassificationResult* classification_result,
+            int expected_classifications_size) {
+  // ASSERT rather than EXPECT so a null result is not dereferenced below.
+  ASSERT_NE(classification_result, nullptr);
+  EXPECT_EQ(classification_result->size, expected_classifications_size);
+  EXPECT_NE(classification_result->classifications, nullptr);
+}
+
+// Checks one classification head; a null `expected_head_name` skips the
+// name comparison.
+void Verify(TfLiteClassifications& classifications,
+            int expected_categories_size,
+            int expected_head_index,
+            char const* expected_head_name) {
+  EXPECT_EQ(classifications.size, expected_categories_size);
+  EXPECT_EQ(classifications.head_index, expected_head_index);
+  ASSERT_NE(classifications.head_name, nullptr);
+  if (expected_head_name) {
+    EXPECT_EQ(strcmp(classifications.head_name, expected_head_name), 0);
+  }
+  EXPECT_NE(classifications.categories, nullptr);
+}
+
+// Checks one category; a null `expected_label` skips the label comparison.
+void Verify(TfLiteCategory& category,
+            int expected_index,
+            char const* expected_label,
+            float expected_score) {
+  const float kPrecision = 1e-6;
+  EXPECT_EQ(category.index, expected_index);
+  EXPECT_NE(category.label, nullptr);
+
+  if (category.label && expected_label) {
+    EXPECT_EQ(strcmp(category.label, expected_label), 0);
+  }
+
+  EXPECT_EQ(category.display_name, nullptr);
+  EXPECT_NEAR(category.score, expected_score, kPrecision);
+}
+
+// Checks that `error` is set, carries `error_code`, and has a message that
+// contains `message`.
+void Verify(TfLiteSupportError* error,
+            TfLiteSupportErrorCode error_code,
+            char const* message) {
+  ASSERT_NE(error, nullptr);
+  EXPECT_EQ(error->code, error_code);
+  EXPECT_NE(error->message, nullptr);
+  EXPECT_THAT(error->message, HasSubstr(message));
+}
+
+class AudioClassifierFromOptionsTest : public tflite_shims::testing::Test {};
+
+TEST_F(AudioClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
+  TfLiteAudioClassifierOptions options = TfLiteAudioClassifierOptionsCreate();
+
+  TfLiteSupportError* error = nullptr;
+  TfLiteAudioClassifier* audio_classifier =
+      TfLiteAudioClassifierFromOptions(&options, &error);
+
+  EXPECT_EQ(audio_classifier, nullptr);
+  if (audio_classifier) {
+    TfLiteAudioClassifierDelete(audio_classifier);
+  }
+
+  Verify(error, kInvalidArgumentError,
+         "INVALID_ARGUMENT: Missing mandatory `model_file` field in "
+         "`base_options`");
+
+  TfLiteSupportErrorDelete(error);
+}
+
+TEST_F(AudioClassifierFromOptionsTest, SucceedsWithModelPath) {
+  std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
+                                    kYamNetAudioClassifierWithMetadata);
+  TfLiteAudioClassifierOptions options = TfLiteAudioClassifierOptionsCreate();
+  options.base_options.model_file.file_path = model_path.data();
+  TfLiteAudioClassifier* audio_classifier =
+      TfLiteAudioClassifierFromOptions(&options, nullptr);
+
+  EXPECT_NE(audio_classifier, nullptr);
+  TfLiteAudioClassifierDelete(audio_classifier);
+}
+
+class AudioClassifierClassifyTest : public tflite_shims::testing::Test {
+ protected:
+  void SetUp() override {
+    std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
+                                      kYamNetAudioClassifierWithMetadata);
+
+    TfLiteAudioClassifierOptions options = TfLiteAudioClassifierOptionsCreate();
+    options.base_options.model_file.file_path = model_path.data();
+    audio_classifier = TfLiteAudioClassifierFromOptions(&options, nullptr);
+    ASSERT_NE(audio_classifier, nullptr);
+  }
+
+  void TearDown() override { TfLiteAudioClassifierDelete(audio_classifier); }
+  TfLiteAudioClassifier* audio_classifier = nullptr;
+};
+
+TEST_F(AudioClassifierClassifyTest, SucceedsWithAudioFile) {
+  int input_buffer_size = TfLiteAudioClassifierGetRequiredInputBufferSize(
+      audio_classifier, nullptr);
+  ASSERT_NE(input_buffer_size, -1);
+
+  SUPPORT_ASSERT_OK_AND_ASSIGN(
+      TfLiteAudioBuffer audio_buffer,
+      LoadAudioBufferFromFileNamed("speech.wav", input_buffer_size));
+
+  TfLiteSupportError* classify_error = nullptr;
+  TfLiteClassificationResult* classification_result =
+      TfLiteAudioClassifierClassify(audio_classifier, &audio_buffer,
+                                    &classify_error);
+
+  TfLiteAudioBufferDeleteData(audio_buffer);
+
+  // Fail with the error message instead of dereferencing a null result.
+  ASSERT_EQ(classify_error, nullptr) << classify_error->message;
+  ASSERT_NE(classification_result, nullptr);
+  Verify(classification_result, 1);
+  Verify(classification_result->classifications[0], 521, 0, "scores");
+  Verify(classification_result->classifications[0].categories[0], 0, "Speech",
+         0.917969);
+  Verify(classification_result->classifications[0].categories[1], 500,
+         "Inside, small room", 0.058594);
+  Verify(classification_result->classifications[0].categories[2], 494,
+         "Silence", 0.011719);
+
+  TfLiteClassificationResultDelete(classification_result);
+}
+
+}  // namespace
+}  // namespace audio
+}  // namespace task
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/BUILD
index fc723360..2bef0251 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/BUILD
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/BUILD
@@ -28,8 +28,8 @@
         "//tensorflow_lite_support/c/task/processor:classification_result",
         "//tensorflow_lite_support/c/task/vision/core:frame_buffer",
         "//tensorflow_lite_support/cc/port:gtest_main",
+        "//tensorflow_lite_support/cc/task/vision/utils:image_utils",
         "//tensorflow_lite_support/cc/test:test_utils",
-        "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
     ],
 )
 
@@ -50,8 +50,8 @@
         "//tensorflow_lite_support/c/task/processor:detection_result",
         "//tensorflow_lite_support/c/task/vision/core:frame_buffer",
         "//tensorflow_lite_support/cc/port:gtest_main",
+        "//tensorflow_lite_support/cc/task/vision/utils:image_utils",
         "//tensorflow_lite_support/cc/test:test_utils",
-        "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
     ],
 )
 
@@ -71,7 +71,7 @@
         "//tensorflow_lite_support/c/task/processor:segmentation_result",
         "//tensorflow_lite_support/c/task/vision/core:frame_buffer",
         "//tensorflow_lite_support/cc/port:gtest_main",
+        "//tensorflow_lite_support/cc/task/vision/utils:image_utils",
         "//tensorflow_lite_support/cc/test:test_utils",
-        "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
     ],
 )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
index b398b7ad..cce2fa63 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
@@ -24,8 +24,8 @@
 #include "tensorflow_lite_support/cc/port/gmock.h"
 #include "tensorflow_lite_support/cc/port/gtest.h"
 #include "tensorflow_lite_support/cc/port/status_matchers.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
 #include "tensorflow_lite_support/cc/test/test_utils.h"
-#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
 
 namespace tflite {
 namespace task {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc
index 81ade945..c03c15d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_segmenter_test.cc
@@ -26,8 +26,8 @@
 #include "tensorflow_lite_support/cc/port/gmock.h"
 #include "tensorflow_lite_support/cc/port/gtest.h"
 #include "tensorflow_lite_support/cc/port/status_matchers.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
 #include "tensorflow_lite_support/cc/test/test_utils.h"
-#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
 
 namespace tflite {
 namespace task {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc
index 99cd003..78d78f5d 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/object_detector_test.cc
@@ -24,8 +24,8 @@
 #include "tensorflow_lite_support/cc/port/gmock.h"
 #include "tensorflow_lite_support/cc/port/gtest.h"
 #include "tensorflow_lite_support/cc/port/status_matchers.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
 #include "tensorflow_lite_support/cc/test/test_utils.h"
-#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
 
 namespace tflite {
 namespace task {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h
new file mode 100644
index 0000000..74bc1a6
--- /dev/null
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h
@@ -0,0 +1,21 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_
+
+#include "gtest/benchmark.h"
+
+#endif  // TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h
new file mode 100644
index 0000000..6d96680
--- /dev/null
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h
@@ -0,0 +1,62 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_
+
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#define SUPPORT_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
+#define SUPPORT_STATUS_MACROS_IMPL_CONCAT_(x, y) \
+  SUPPORT_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y)
+
+// Fails the current test and returns from the enclosing function unless
+// `(expr).ok()` is true.
+#undef SUPPORT_ASSERT_OK
+#define SUPPORT_ASSERT_OK(expr) \
+  SUPPORT_ASSERT_OK_IMPL_(      \
+      SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status, __LINE__), expr)
+
+#define SUPPORT_ASSERT_OK_IMPL_(status, expr) \
+  auto status = (expr);                       \
+  ASSERT_TRUE(status.ok());
+
+// Records a non-fatal test failure unless `(expr).ok()` is true.
+#undef SUPPORT_EXPECT_OK
+#define SUPPORT_EXPECT_OK(expr) \
+  SUPPORT_EXPECT_OK_IMPL_(      \
+      SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status, __LINE__), expr)
+
+#define SUPPORT_EXPECT_OK_IMPL_(status, expr) \
+  auto status = (expr);                       \
+  EXPECT_TRUE(status.ok());
+
+// Evaluates `rexpr` once, asserts that the result is ok, and moves its value
+// into `lhs`, which may be a declaration such as `Foo foo` or an lvalue.
+#undef SUPPORT_ASSERT_OK_AND_ASSIGN
+#define SUPPORT_ASSERT_OK_AND_ASSIGN(lhs, rexpr)                             \
+  SUPPORT_ASSERT_OK_AND_ASSIGN_IMPL_(                                        \
+      SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, \
+      rexpr)
+
+#define SUPPORT_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, rexpr) \
+  auto statusor = (rexpr);                                        \
+  ASSERT_TRUE(statusor.ok());                                     \
+  lhs = std::move(statusor.value())
+
+#endif  // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h
new file mode 100644
index 0000000..5e4334d
--- /dev/null
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h
@@ -0,0 +1,21 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_
+
+#include "gmock/gmock.h"
+
+#endif  // TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h
new file mode 100644
index 0000000..dbe2e5e6
--- /dev/null
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h
@@ -0,0 +1,21 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_
+
+#include "gtest/gtest.h"
+
+#endif  // TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h
new file mode 100644
index 0000000..3cde2ab8
--- /dev/null
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h
@@ -0,0 +1,33 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO2_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO2_H_
+
+#include "google/protobuf/message_lite.h"
+#include "google/protobuf/text_format.h"
+
+namespace tflite {
+namespace support {
+namespace proto {
+
+// Aliases for the protobuf types used by this library.
+using TextFormat = ::google::protobuf::TextFormat;
+using MessageLite = ::google::protobuf::MessageLite;
+
+}  // namespace proto
+}  // namespace support
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO2_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/BUILD
index fb1f02d7..d39aa2e 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/BUILD
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/BUILD
@@ -186,6 +186,8 @@
         "//tensorflow_lite_support/cc/task/processor/proto:embedding_options_cc_proto",
         "//tensorflow_lite_support/cc/task/processor/proto:search_options_cc_proto",
         "//tensorflow_lite_support/cc/task/processor/proto:search_result_cc_proto",
+        "//tensorflow_lite_support/metadata:metadata_schema_cc",
+        "//tensorflow_lite_support/metadata/cc:metadata_extractor",
         "//tensorflow_lite_support/scann_ondevice/cc:index",
         "//tensorflow_lite_support/scann_ondevice/cc/core:partitioner",
         "//tensorflow_lite_support/scann_ondevice/cc/core:processor",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/BUILD
index 9b64c859..61bd5b6 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/BUILD
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/BUILD
@@ -44,6 +44,9 @@
     srcs = ["classifications.proto"],
     api_version = 2,
     proto_deps = [":classifications_proto"],
+    py_proto_deps = [
+        ":class_py_pb2",
+    ],
 )
 
 proto_library(
@@ -213,6 +216,16 @@
     ],
 )
 
+support_py_proto_library(
+    name = "search_options_py_pb2",
+    srcs = ["search_options.proto"],
+    api_version = 2,
+    proto_deps = [":search_options_proto"],
+    py_proto_deps = [
+        "//tensorflow_lite_support/cc/task/core/proto:external_file_py_pb2",
+    ],
+)
+
 proto_library(
     name = "search_result_proto",
     srcs = ["search_result.proto"],
@@ -222,3 +235,10 @@
     name = "search_result_cc_proto",
     deps = [":search_result_proto"],
 )
+
+support_py_proto_library(
+    name = "search_result_py_pb2",
+    srcs = ["search_result.proto"],
+    api_version = 2,
+    proto_deps = [":search_result_proto"],
+)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/search_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/search_options.proto
index d50c60ba..6aacb740 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/search_options.proto
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/proto/search_options.proto
@@ -25,9 +25,12 @@
 // Options for search processor.
 // Next Id: 4
 message SearchOptions {
-  // The index file to search into.
+  // The index file to search into. Mandatory only if the index is not attached
+  // to the output tensor metadata as an AssociatedFile with type
+  // SCANN_INDEX_FILE.
+  // Note that in case both are provided, this field takes precedence.
   optional core.ExternalFile index_file = 1;
 
-  // Number of nearest neighbor results to return.
-  optional int32 num_results = 2 [default = 5];
+  // Maximum number of nearest neighbor results to return.
+  optional int32 max_results = 2 [default = 5];
 }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc new file mode 100644 index 0000000..a2fa1f8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.cc
@@ -0,0 +1,402 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/processor/search_postprocessor.h" + +#include <algorithm> +#include <cmath> +#include <cstdint> +#include <initializer_list> +#include <limits> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "Eigen/Core" // from @eigen +#include "absl/memory/memory.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl +#include "absl/types/span.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/search_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h" +#include
"tensorflow_lite_support/metadata/cc/metadata_extractor.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h" +#include "tensorflow_lite_support/scann_ondevice/cc/index.h" +#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h" + +namespace tflite { +namespace task { +namespace processor { + +namespace { + +constexpr int kNoNeighborId = -1; + +using ::tflite::TensorMetadata; +using ::tflite::metadata::ModelMetadataExtractor; +using ::tflite::scann_ondevice::Index; +using ::tflite::scann_ondevice::IndexConfig; +using ::tflite::scann_ondevice::core::AsymmetricHashFindNeighbors; +using ::tflite::scann_ondevice::core::DistanceMeasure; +using ::tflite::scann_ondevice::core::FloatFindNeighbors; +using ::tflite::scann_ondevice::core::QueryInfo; +using ::tflite::scann_ondevice::core::ScannOnDeviceConfig; +using ::tflite::scann_ondevice::core::TopN; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::core::ExternalFileHandler; +using ::tflite::task::core::TfLiteEngine; +using ::tflite::task::processor::Embedding; + +using Matrix8u = + Eigen::Matrix<uint8_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>; + +absl::StatusOr<std::unique_ptr<EmbeddingPostprocessor>> +CreateEmbeddingPostprocessor(TfLiteEngine* engine, + const std::initializer_list<int> output_indices, + std::unique_ptr<EmbeddingOptions> options) { + if (options->quantize()) { + // ScaNN only supports searching from float embeddings.
+ return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, + "Setting EmbeddingOptions.quantize = true " + "is not allowed in searchers.", + TfLiteSupportStatus::kInvalidArgumentError); + } + return EmbeddingPostprocessor::Create(engine, output_indices, + std::move(options)); +} + +absl::Status SanityCheckOptions(const SearchOptions& options) { + if (options.max_results() < 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("SearchOptions.max_results must be > 0, found %d.", + options.max_results()), + TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +absl::Status SanityCheckIndexConfig(const IndexConfig& config) { + switch (config.embedding_type()) { + case IndexConfig::UNSPECIFIED: + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid IndexConfig: embedding_type must not be left UNSPECIFIED.", + TfLiteSupportStatus::kInvalidArgumentError); + case IndexConfig::FLOAT: + if (config.scann_config().has_indexer()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid IndexConfig: embedding_type is set to FLOAT but ScaNN " + "config specifies a product quantization codebook.", + TfLiteSupportStatus::kInvalidArgumentError); + } + break; + case IndexConfig::UINT8: + if (!config.scann_config().has_indexer()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid IndexConfig: embedding_type is set to UINT8 but ScaNN " + "config doesn't specify a product quantization codebook.", + TfLiteSupportStatus::kInvalidArgumentError); + } + break; + default: + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + "Invalid IndexConfig: unexpected value for embedding_type.", + TfLiteSupportStatus::kError); + } + return absl::OkStatus(); +} + +StatusOr<absl::string_view> GetIndexFileContentFromMetadata( + const ModelMetadataExtractor* metadata_extractor, + const TensorMetadata* tensor_metadata) { + // The processor is created with requires_metadata = false, so metadata may be + // absent; never dereference these pointers without checking them first. + if (metadata_extractor == nullptr || tensor_metadata == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Unable to find index file: SearchOptions.index_file is not set and no " + "metadata could be found for the output tensor.", + TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError); + } + auto
index_file_name = ModelMetadataExtractor::FindFirstAssociatedFileName( + *tensor_metadata, tflite::AssociatedFileType_SCANN_INDEX_FILE); + if (index_file_name.empty()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Unable to find index file: SearchOptions.index_file is not set and no " + "AssociatedFile with type SCANN_INDEX_FILE could be found in the " + "output tensor metadata.", + TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError); + } + return metadata_extractor->GetAssociatedFile(index_file_name); +} + +absl::StatusOr<DistanceMeasure> GetDistanceMeasure( + const ScannOnDeviceConfig& config) { + DistanceMeasure measure = config.query_distance(); + if (measure == tflite::scann_ondevice::core::UNSPECIFIED) { + if (config.has_indexer() && config.indexer().has_asymmetric_hashing()) { + measure = config.indexer().asymmetric_hashing().query_distance(); + } else if (config.has_partitioner()) { + measure = config.partitioner().query_distance(); + } else { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "ScaNN config does not provide mandatory DistanceMeasure.", + TfLiteSupportStatus::kInvalidArgumentError); + } + + if (measure == tflite::scann_ondevice::core::UNSPECIFIED) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "UNSPECIFIED is not a valid value for ScaNN config DistanceMeasure.", + TfLiteSupportStatus::kInvalidArgumentError); + } + + // Make sure the query distance in different places are consistent.
+ if (config.has_partitioner()) { + DistanceMeasure partitioner_measure = + config.partitioner().query_distance(); + if (measure != partitioner_measure) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("DistanceMeasure %s is different from " + "DistanceMeasure %s found in partitioner config.", + DistanceMeasure_Name(measure), + DistanceMeasure_Name(partitioner_measure)), + TfLiteSupportStatus::kInvalidArgumentError); + } + } + } + return measure; +} + +absl::Status ConvertEmbeddingToEigenMatrix(const Embedding& embedding, + Eigen::MatrixXf* matrix) { + if (embedding.feature_vector().value_float().empty()) { + // This should be caught upstream at EmbeddingPostprocessor creation. + return CreateStatusWithPayload(absl::StatusCode::kInternal, + "Float query embedding is empty.", + TfLiteSupportStatus::kError); + } + Eigen::Map<const Eigen::VectorXf> query_ptr( + embedding.feature_vector().value_float().data(), + embedding.feature_vector().value_float().size()); + matrix->resize(embedding.feature_vector().value_float().size(), 1); + matrix->col(0) = query_ptr; + return absl::OkStatus(); +} + +} // namespace + +/* static */ +StatusOr<std::unique_ptr<SearchPostprocessor>> SearchPostprocessor::Create( + TfLiteEngine* engine, + int output_index, + std::unique_ptr<SearchOptions> search_options, + std::unique_ptr<EmbeddingOptions> embedding_options) { + ASSIGN_OR_RETURN(auto embedding_postprocessor, + CreateEmbeddingPostprocessor(engine, {output_index}, + std::move(embedding_options))); + + ASSIGN_OR_RETURN(auto search_processor, + Processor::Create<SearchPostprocessor>( + /* num_expected_tensors =*/1, engine, {output_index}, + /* requires_metadata =*/false)); + + RETURN_IF_ERROR(search_processor->Init(std::move(embedding_postprocessor), + std::move(search_options))); + return search_processor; +} + +StatusOr<SearchResult> SearchPostprocessor::Postprocess() { + // Extract embedding.
+ Embedding embedding; + RETURN_IF_ERROR(embedding_postprocessor_->Postprocess(&embedding)); + // Convert embedding to Eigen matrix, as expected by ScaNN. + Eigen::MatrixXf query; + RETURN_IF_ERROR(ConvertEmbeddingToEigenMatrix(embedding, &query)); + + // Identify partitions to search. + std::vector<std::vector<int>> leaves_to_search( + 1, std::vector<int>(num_leaves_to_search_, -1)); + if (!partitioner_->Partition(query, &leaves_to_search)) { + return CreateStatusWithPayload(absl::StatusCode::kInternal, + "Partitioning failed.", + TfLiteSupportStatus::kError); + } + + // Prepare search results. + std::vector<TopN> top_n; + top_n.emplace_back( + options_->max_results(), + std::make_pair(std::numeric_limits<float>::max(), kNoNeighborId)); + // Perform search. + if (quantizer_) { + RETURN_IF_ERROR( + QuantizedSearch(query, leaves_to_search[0], absl::MakeSpan(top_n))); + } else { + RETURN_IF_ERROR( + LinearSearch(query, leaves_to_search[0], absl::MakeSpan(top_n))); + } + + // Build results. + SearchResult search_result; + for (const auto& [distance, id] : top_n[0].Take()) { + if (id == kNoNeighborId) { + break; + } + ASSIGN_OR_RETURN(auto metadata, index_->GetMetadataAtIndex(id)); + NearestNeighbor* nearest_neighbor = search_result.add_nearest_neighbors(); + nearest_neighbor->set_distance(distance); + nearest_neighbor->set_metadata(std::string(metadata)); + } + return search_result; +} + +StatusOr<absl::string_view> SearchPostprocessor::GetUserInfo() { + return index_->GetUserInfo(); +} + +absl::Status SearchPostprocessor::Init( + std::unique_ptr<EmbeddingPostprocessor> embedding_postprocessor, + std::unique_ptr<SearchOptions> options) { + embedding_postprocessor_ = std::move(embedding_postprocessor); + RETURN_IF_ERROR(SanityCheckOptions(*options)); + options_ = std::move(options); + + // Initialize index.
+ absl::string_view index_file_content; + if (options_->has_index_file()) { + ASSIGN_OR_RETURN( + index_file_handler_, + ExternalFileHandler::CreateFromExternalFile(&options_->index_file())); + index_file_content = index_file_handler_->GetFileContent(); + } else { + ASSIGN_OR_RETURN(index_file_content, + GetIndexFileContentFromMetadata(GetMetadataExtractor(), + GetTensorMetadata())); + } + ASSIGN_OR_RETURN(index_, + Index::CreateFromIndexBuffer(index_file_content.data(), + index_file_content.size())); + ASSIGN_OR_RETURN(index_config_, index_->GetIndexConfig()); + RETURN_IF_ERROR(SanityCheckIndexConfig(index_config_)); + // Get distance measure once and for all. + ASSIGN_OR_RETURN(distance_measure_, + GetDistanceMeasure(index_config_.scann_config())); + + // Initialize partitioner. + if (index_config_.scann_config().has_partitioner()) { + partitioner_ = tflite::scann_ondevice::core::Partitioner::Create( + index_config_.scann_config().partitioner()); + num_leaves_to_search_ = std::min( + static_cast<int>(std::ceil( + partitioner_->NumPartitions() * + index_config_.scann_config().partitioner().search_fraction())), + partitioner_->NumPartitions()); + } else { + partitioner_ = + absl::make_unique<tflite::scann_ondevice::core::NoOpPartitioner>(); + num_leaves_to_search_ = partitioner_->NumPartitions(); + } + + // Initialize product quantizer if needed. + if (index_config_.scann_config().has_indexer()) { + quantizer_ = tflite::scann_ondevice::core::AsymmetricHashQuerier::Create( + index_config_.scann_config().indexer().asymmetric_hashing()); + } + + return absl::OkStatus(); +} + +absl::Status SearchPostprocessor::QuantizedSearch( + Eigen::Ref<Eigen::MatrixXf> query, + std::vector<int> leaves_to_search, + absl::Span<TopN> top_n) { + int dim = index_config_.embedding_dim(); + // Prepare QueryInfo used for all leaves.
+ QueryInfo query_info; + if (!quantizer_->Process(query, &query_info)) { + return CreateStatusWithPayload(absl::StatusCode::kInternal, + "Query quantization failed.", + TfLiteSupportStatus::kError); + } + for (int leaf_id : leaves_to_search) { + // Load partition into Eigen matrix. + ASSIGN_OR_RETURN(auto partition, index_->GetPartitionAtIndex(leaf_id)); + int partition_size = partition.size() / dim; + Eigen::Map<const Matrix8u> database( + reinterpret_cast<const uint8_t*>(partition.data()), dim, + partition_size); + // Perform search. + int global_offset = index_config_.global_partition_offsets(leaf_id); + if (!AsymmetricHashFindNeighbors(query_info, database, global_offset, + top_n)) { + return CreateStatusWithPayload(absl::StatusCode::kInternal, + "Nearest neighbor search failed.", + TfLiteSupportStatus::kError); + } + } + return absl::OkStatus(); +} + +absl::Status SearchPostprocessor::LinearSearch( + Eigen::Ref<Eigen::MatrixXf> query, + std::vector<int> leaves_to_search, + absl::Span<TopN> top_n) { + int dim = index_config_.embedding_dim(); + for (int leaf_id : leaves_to_search) { + // Load partition into Eigen matrix. + ASSIGN_OR_RETURN(auto partition, index_->GetPartitionAtIndex(leaf_id)); + int partition_size = partition.size() / (dim * sizeof(float)); + Eigen::Map<const Eigen::MatrixXf> database( + reinterpret_cast<const float*>(partition.data()), dim, partition_size); + // Perform search. + int global_offset = index_config_.global_partition_offsets(leaf_id); + if (!FloatFindNeighbors(query, database, global_offset, distance_measure_, + top_n)) { + return CreateStatusWithPayload(absl::StatusCode::kInternal, + "Nearest neighbor search failed.", + TfLiteSupportStatus::kError); + } + } + return absl::OkStatus(); +} + +} // namespace processor +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h new file mode 100644 index 0000000..d79bc85 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/processor/search_postprocessor.h
@@ -0,0 +1,112 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_SEARCH_POSTPROCESSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_SEARCH_POSTPROCESSOR_H_ + +#include <cstdint> +#include <initializer_list> +#include <memory> +#include <vector> + +#include "Eigen/Core" // from @eigen +#include "absl/strings/string_view.h" // from @com_google_absl +#include "absl/types/span.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h" +#include "tensorflow_lite_support/cc/task/processor/processor.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/search_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h" +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h" +#include 
"tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h" +#include "tensorflow_lite_support/scann_ondevice/cc/index.h" +#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h" + +namespace tflite { +namespace task { +namespace processor { + +// Postprocessor in charge of performing embedding extraction followed by +// nearest-neighbor search. +// +// This postprocessor works with the following output tensor: +// (kTfLiteUInt8/kTfLiteFloat32) +// - `N` components corresponding to the `N` dimensions of the returned +// feature vector for this output layer. +// - Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`. +class SearchPostprocessor : public Postprocessor { + public: + static tflite::support::StatusOr<std::unique_ptr<SearchPostprocessor>> Create( + tflite::task::core::TfLiteEngine* engine, + int output_index, + std::unique_ptr<SearchOptions> search_options, + std::unique_ptr<EmbeddingOptions> embedding_options = + std::make_unique<EmbeddingOptions>()); + + // Converts the tensor outputs to embeddings, then performs a nearest-neighbor + // search in the index. + tflite::support::StatusOr<SearchResult> Postprocess(); + + // Provides access to the opaque user info stored in the index file (if any), + // in raw binary form. Returns an empty string if the index doesn't contain + // user info. 
+ tflite::support::StatusOr<absl::string_view> GetUserInfo(); + + private: + using Postprocessor::Postprocessor; + + absl::Status Init( + std::unique_ptr<EmbeddingPostprocessor> embedding_postprocessor, + std::unique_ptr<SearchOptions> options); + + absl::Status QuantizedSearch( + Eigen::Ref<Eigen::MatrixXf> query, + std::vector<int> leaves_to_search, + absl::Span<tflite::scann_ondevice::core::TopN> top_n); + absl::Status LinearSearch( + Eigen::Ref<Eigen::MatrixXf> query, + std::vector<int> leaves_to_search, + absl::Span<tflite::scann_ondevice::core::TopN> top_n); + + std::unique_ptr<SearchOptions> options_; + + // Encapsulated EmbeddingPostprocessor converting raw tensors to embeddings. + std::unique_ptr<EmbeddingPostprocessor> embedding_postprocessor_; + + // Index management. + std::unique_ptr<tflite::task::core::ExternalFileHandler> index_file_handler_; + std::unique_ptr<tflite::scann_ondevice::Index> index_; + tflite::scann_ondevice::IndexConfig index_config_; + + // ScaNN management. + int num_leaves_to_search_; + tflite::scann_ondevice::core::DistanceMeasure distance_measure_; + std::unique_ptr<tflite::scann_ondevice::core::PartitionerInterface> + partitioner_; + std::shared_ptr<tflite::scann_ondevice::core::AsymmetricHashQuerier> + quantizer_; +}; + +} // namespace processor +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_SEARCH_POSTPROCESSOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/BUILD index 72e3961..0ffb4cd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/BUILD
@@ -121,3 +121,24 @@ "@org_tensorflow//tensorflow/lite/c:common", ], ) + +cc_library( + name = "image_utils", + srcs = ["image_utils.cc"], + hdrs = ["image_utils.h"], + visibility = [ + "//tensorflow_lite_support:internal", + ], + deps = [ + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@stblib//:stb_image", + "@stblib//:stb_image_write", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc index cea7ef3..9a5b9616 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
@@ -299,11 +299,6 @@ return absl::InvalidArgumentError( "Grayscale format does not convert to other formats."); case FrameBuffer::Format::kRGB: - if (to_format == FrameBuffer::Format::kRGBA) { - return absl::InvalidArgumentError( - "RGB format does not convert to RGBA"); - } - return absl::OkStatus(); case FrameBuffer::Format::kRGBA: case FrameBuffer::Format::kNV12: case FrameBuffer::Format::kNV21:
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc similarity index 97% rename from third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc rename to third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc index d5c0c58..d5b277a 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.cc
@@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" #include <cstdlib> #include <cstring>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h similarity index 100% rename from third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h rename to third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/image_utils.h
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc index a00c8223..a0ee2da 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
@@ -24,6 +24,7 @@ #include "absl/strings/str_cat.h" // from @com_google_absl #include "absl/strings/str_format.h" // from @com_google_absl #include "libyuv.h" // from @libyuv +#include "libyuv/convert_argb.h" // from @libyuv #include "tensorflow_lite_support/cc/common.h" #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/status_macros.h" @@ -553,6 +554,20 @@ return ConvertFromYv(*yuv_frame_buffer, output_buffer); } return absl::OkStatus(); + } else if (output_buffer->format() == FrameBuffer::Format::kRGBA) { + // RGB24 is BGR in memory and ARGB is BGRA in memory. The addition of the + // alpha channel will not impact the RGB ordering. + int ret = libyuv::RGB24ToARGB( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kInternal, "Libyuv RGB24ToARGB operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + return absl::OkStatus(); } return CreateStatusWithPayload( StatusCode::kInternal,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/BUILD index 7376ad34..6e9955fa 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/BUILD
@@ -25,8 +25,8 @@ "//tensorflow_lite_support/cc/port:gtest_main", "//tensorflow_lite_support/cc/task/core:task_utils", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", "//tensorflow_lite_support/cc/test:test_utils", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", "@com_google_absl//absl/status", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc index 9ae9435..ef0e783 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/processor/image_preprocessor_test.cc
@@ -24,8 +24,8 @@ #include "tensorflow_lite_support/cc/port/status_matchers.h" #include "tensorflow_lite_support/cc/task/core/task_utils.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" #include "tensorflow_lite_support/cc/test/test_utils.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" namespace tflite { namespace task {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/BUILD index c2d80d4..2ea185b3 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/BUILD
@@ -83,3 +83,34 @@ "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", ], ) + +cc_test( + name = "text_searcher_test", + srcs = ["text_searcher_test.cc"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:mobilebert_embedding_with_metadata", + "//tensorflow_lite_support/cc/test/testdata/task/text:regex_embedding_with_metadata", + "//tensorflow_lite_support/cc/test/testdata/task/text:test_indices", + "//tensorflow_lite_support/cc/test/testdata/task/text:test_searchers", + "//tensorflow_lite_support/cc/test/testdata/task/text:universal_sentence_encoder_qa", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core/proto:base_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:embedding_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:search_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:search_result_cc_proto", + "//tensorflow_lite_support/cc/task/text:text_searcher", + "//tensorflow_lite_support/cc/task/text/proto:text_searcher_options_cc_proto", + "//tensorflow_lite_support/cc/test:test_utils", + "//tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_op_resolver", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc index 931a44e..b097813 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_embedder_test.cc
@@ -126,6 +126,38 @@ absl::StrCat(support::TfLiteSupportStatus::kInvalidArgumentError)))); } +TEST(EmbedTest, SucceedsWithMobileBertModel) { + TextEmbedderOptions options = GetBasicOptions(kMobileBert); + // No Embedding options means all head get a default option. + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder, + TextEmbedder::CreateFromOptions(options)); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + auto result0, + text_embedder->Embed("it's a charming and often affecting journey")); + EXPECT_EQ(result0.embeddings_size(), 1); + EXPECT_EQ(result0.embeddings(0).feature_vector().value_float_size(), 512); + + EXPECT_NEAR(result0.embeddings(0).feature_vector().value_float(0), 19.9016f, + kValueDiffTolerance); + + SUPPORT_ASSERT_OK_AND_ASSIGN( + auto result1, text_embedder->Embed("what a great and fantastic trip")); + EXPECT_EQ(result1.embeddings_size(), 1); + EXPECT_EQ(result1.embeddings(0).feature_vector().value_float_size(), 512); + + EXPECT_NEAR(result1.embeddings(0).feature_vector().value_float(0), 22.626251f, + kValueDiffTolerance); + + // Check cosine similarity. + SUPPORT_ASSERT_OK_AND_ASSIGN( + double similarity, + TextEmbedder::CosineSimilarity(result0.embeddings(0).feature_vector(), + result1.embeddings(0).feature_vector())); + double expected_similarity = 0.969514; + EXPECT_NEAR(similarity, expected_similarity, kSimilarityTolerancy); +} + TEST(EmbedTest, SucceedsWithRegexModel) { TextEmbedderOptions options = GetBasicOptions(kRegexOneEmbeddingModel); // No Embedding options means all head get a default option.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc new file mode 100644 index 0000000..f38615c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/text/text_searcher_test.cc
@@ -0,0 +1,425 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/text_searcher.h" + +#include <memory> +#include <string> + +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/cord.h" // from @com_google_absl +#include "absl/strings/str_cat.h" // from @com_google_absl +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/search_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h" +#include "tensorflow_lite_support/cc/task/text/proto/text_searcher_options.pb.h" +#include "tensorflow_lite_support/cc/test/test_utils.h" +#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h" + +namespace tflite { 
+namespace task { +namespace text { +namespace { + +using ::testing::HasSubstr; +using ::testing::Optional; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; +using ::tflite::support::kTfLiteSupportPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::processor::NearestNeighbor; +using ::tflite::task::processor::SearchResult; + +constexpr char kTestDataDirectory[] = + "/tensorflow_lite_support/cc/test/testdata/task/text/"; +constexpr char kMobileBertEmbedder[] = + "mobilebert_embedding_with_metadata.tflite"; +constexpr char kMobileBertIndex[] = "mobilebert_index.ldb"; +constexpr char kMobileBertSearcher[] = "mobilebert_searcher.tflite"; +constexpr char kRegexEmbedder[] = "regex_one_embedding_with_metadata.tflite"; +constexpr char kRegexIndex[] = "regex_index.ldb"; +constexpr char kRegexSearcher[] = "regex_searcher.tflite"; +constexpr char kUSEEmbedder[] = + "universal_sentence_encoder_qa_with_metadata.tflite"; +constexpr char kUSEIndex[] = "universal_sentence_encoder_index.ldb"; +constexpr char kUSESearcher[] = "universal_sentence_encoder_searcher.tflite"; + +// Checks that the two provided `SearchResult` protos are equal, with a +// tolerancy on floating-point scores to account for numerical instabilities. 
+void ExpectApproximatelyEqual(const SearchResult& actual, + const SearchResult& expected) { + const float kPrecision = 1e-5; + EXPECT_EQ(actual.nearest_neighbors_size(), expected.nearest_neighbors_size()); + for (int i = 0; i < actual.nearest_neighbors_size(); ++i) { + const NearestNeighbor& a = actual.nearest_neighbors(i); + const NearestNeighbor& b = expected.nearest_neighbors(i); + EXPECT_EQ(a.metadata(), b.metadata()); + EXPECT_NEAR(a.distance(), b.distance(), kPrecision); + } +} + +std::unique_ptr<tflite::OpResolver> GetOpResolver( + bool is_universal_sentence_encoder) { + if (is_universal_sentence_encoder) { + return CreateQACustomOpResolver(); + } else { + return absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(); + } +} + +struct CreateFromOptionsParams { + std::string name; + std::string embedder_model_name; + std::string searcher_model_name; + bool is_universal_sentence_encoder; + std::string index_name; +}; + +class CreateFromOptionsTest : public TestWithParam<CreateFromOptionsParams> {}; + +TEST_P(CreateFromOptionsTest, SucceedsWithStandaloneIndex) { + TextSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().embedder_model_name)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().index_name)); + + SUPPORT_ASSERT_OK(TextSearcher::CreateFromOptions( + options, GetOpResolver(GetParam().is_universal_sentence_encoder))); +} + +TEST_P(CreateFromOptionsTest, SucceedsWithMetadataIndex) { + TextSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().searcher_model_name)); + options.mutable_embedding_options()->set_l2_normalize(true); + + 
SUPPORT_ASSERT_OK(TextSearcher::CreateFromOptions( + options, GetOpResolver(GetParam().is_universal_sentence_encoder))); +} + +TEST_P(CreateFromOptionsTest, FailsWithMissingModel) { + TextSearcherOptions options; + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().index_name)); + + StatusOr<std::unique_ptr<TextSearcher>> image_searcher_or = + TextSearcher::CreateFromOptions( + options, GetOpResolver(GetParam().is_universal_sentence_encoder)); + + EXPECT_EQ(image_searcher_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + image_searcher_or.status().message(), + HasSubstr("Missing mandatory `model_file` field in `base_options`")); + EXPECT_THAT(image_searcher_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_P(CreateFromOptionsTest, FailsWithMissingIndex) { + TextSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().embedder_model_name)); + options.mutable_embedding_options()->set_l2_normalize(true); + + StatusOr<std::unique_ptr<TextSearcher>> image_searcher_or = + TextSearcher::CreateFromOptions( + options, GetOpResolver(GetParam().is_universal_sentence_encoder)); + + EXPECT_EQ(image_searcher_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + image_searcher_or.status().message(), + HasSubstr("Unable to find index file: SearchOptions.index_file is not " + "set and no AssociatedFile with type SCANN_INDEX_FILE could be " + "found in the output tensor metadata.")); + EXPECT_THAT(image_searcher_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord(absl::StrCat( + TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError)))); +} + +TEST_P(CreateFromOptionsTest, 
FailsWithQuantization) { + TextSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().embedder_model_name)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_embedding_options()->set_quantize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().index_name)); + + StatusOr<std::unique_ptr<TextSearcher>> image_searcher_or = + TextSearcher::CreateFromOptions( + options, GetOpResolver(GetParam().is_universal_sentence_encoder)); + + EXPECT_EQ(image_searcher_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_searcher_or.status().message(), + HasSubstr("Setting EmbeddingOptions.quantize = true is not " + "allowed in searchers")); + EXPECT_THAT(image_searcher_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_P(CreateFromOptionsTest, FailsWithInvalidMaxResults) { + TextSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().embedder_model_name)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().index_name)); + options.mutable_search_options()->set_max_results(-1); + + StatusOr<std::unique_ptr<TextSearcher>> image_searcher_or = + TextSearcher::CreateFromOptions( + options, GetOpResolver(GetParam().is_universal_sentence_encoder)); + + EXPECT_EQ(image_searcher_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_searcher_or.status().message(), + HasSubstr("SearchOptions.max_results must be > 0, found -1")); + 
EXPECT_THAT(image_searcher_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +INSTANTIATE_TEST_SUITE_P( + CreateFromOptionsTest, + CreateFromOptionsTest, + Values(CreateFromOptionsParams{.name = "Bert", + .embedder_model_name = kMobileBertEmbedder, + .searcher_model_name = kMobileBertSearcher, + .is_universal_sentence_encoder = false, + .index_name = kMobileBertIndex}, + CreateFromOptionsParams{.name = "Regex", + .embedder_model_name = kRegexEmbedder, + .searcher_model_name = kRegexSearcher, + .is_universal_sentence_encoder = false, + .index_name = kRegexIndex}, + CreateFromOptionsParams{.name = "USE", + .embedder_model_name = kUSEEmbedder, + .searcher_model_name = kUSESearcher, + .is_universal_sentence_encoder = true, + .index_name = kUSEIndex}), + [](const TestParamInfo<CreateFromOptionsTest::ParamType>& info) { + return info.param.name; + }); + +struct SearchParams { + std::string name; + std::string embedder_model_name; + std::string searcher_model_name; + bool is_universal_sentence_encoder; + std::string index_name; + std::string expected_result; +}; + +class SearchTest : public TestWithParam<SearchParams> {}; + +TEST_P(SearchTest, SucceedsWithStandaloneIndex) { + // Create Searcher. + TextSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().embedder_model_name)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().index_name)); + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<TextSearcher> searcher, + TextSearcher::CreateFromOptions( + options, GetOpResolver(GetParam().is_universal_sentence_encoder))); + + // Perform search. 
+ SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, + searcher->Search("The weather was excellent.")); + + // Check results. + ExpectApproximatelyEqual( + result, ParseTextProtoOrDie<SearchResult>(GetParam().expected_result)); +} + +TEST_P(SearchTest, SucceedsWithMetadataIndex) { + // Create Searcher. + TextSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().searcher_model_name)); + options.mutable_embedding_options()->set_l2_normalize(true); + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<TextSearcher> searcher, + TextSearcher::CreateFromOptions( + options, GetOpResolver(GetParam().is_universal_sentence_encoder))); + + // Perform search. + SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, + searcher->Search("The weather was excellent.")); + + // Check results. + ExpectApproximatelyEqual( + result, ParseTextProtoOrDie<SearchResult>(GetParam().expected_result)); +} + +TEST_P(SearchTest, SucceedsWithMaxResults) { + // Create Searcher. + TextSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().embedder_model_name)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, + GetParam().index_name)); + options.mutable_search_options()->set_max_results(2); + SUPPORT_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<TextSearcher> searcher, + TextSearcher::CreateFromOptions( + options, GetOpResolver(GetParam().is_universal_sentence_encoder))); + + // Perform search. + SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, + searcher->Search("The weather was excellent.")); + + // Check results. 
+ SearchResult all_results = + ParseTextProtoOrDie<SearchResult>(GetParam().expected_result); + SearchResult expected_result; + expected_result.add_nearest_neighbors()->CopyFrom( + all_results.nearest_neighbors(0)); + expected_result.add_nearest_neighbors()->CopyFrom( + all_results.nearest_neighbors(1)); + ExpectApproximatelyEqual(result, expected_result); +} + +INSTANTIATE_TEST_SUITE_P( + SearchTest, + SearchTest, + Values( + SearchParams{ + .name = "Bert", + .embedder_model_name = kMobileBertEmbedder, + .searcher_model_name = kMobileBertSearcher, + .is_universal_sentence_encoder = false, + .index_name = kMobileBertIndex, + .expected_result = R"pb( + nearest_neighbors { + metadata: "The weather was excellent." + distance: 0.0 + } + nearest_neighbors { + metadata: "It was a sunny day." + distance: 0.11537 + } + nearest_neighbors { + metadata: "The sun was shining on that day." + distance: 0.23002 + } + nearest_neighbors { + metadata: "He was very happy with his newly bought car." + distance: 0.32456 + } + nearest_neighbors { + metadata: "The cat is chasing after the mouse." + distance: 0.96693 + } + )pb"}, + SearchParams{ + .name = "Regex", + .embedder_model_name = kRegexEmbedder, + .searcher_model_name = kRegexSearcher, + .is_universal_sentence_encoder = false, + .index_name = kRegexIndex, + .expected_result = R"pb( + nearest_neighbors { + metadata: "The weather was excellent." + distance: 0.0 + } + nearest_neighbors { + metadata: "The sun was shining on that day." + distance: 0.00006 + } + nearest_neighbors { + metadata: "The cat is chasing after the mouse." + distance: 0.00009 + } + nearest_neighbors { + metadata: "It was a sunny day." + distance: 0.00011 + } + nearest_neighbors { + metadata: "He was very happy with his newly bought car." 
+ distance: 0.00012 + } + )pb"}, + SearchParams{ + .name = "USE", + .embedder_model_name = kUSEEmbedder, + .searcher_model_name = kUSESearcher, + .is_universal_sentence_encoder = true, + .index_name = kUSEIndex, + .expected_result = R"pb( + nearest_neighbors { + metadata: "The weather was excellent." + distance: 0.0 + } + nearest_neighbors { + metadata: "It was a sunny day." + distance: 0.14636 + } + nearest_neighbors { + metadata: "The sun was shining on that day." + distance: 0.15222 + } + nearest_neighbors { + metadata: "The cat is chasing after the mouse." + distance: 0.35997 + } + nearest_neighbors { + metadata: "He was very happy with his newly bought car." + distance: 0.36693 + } + )pb"}), + [](const TestParamInfo<SearchTest::ParamType>& info) { + return info.param.name; + }); + +} // namespace +} // namespace text +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/BUILD index 73a9a02d..b4b0432b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/BUILD
@@ -34,8 +34,8 @@ "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", "//tensorflow_lite_support/cc/test:test_utils", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:cord", @@ -69,7 +69,7 @@ "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", "@com_google_absl//absl/status", ], ) @@ -99,8 +99,8 @@ "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", "//tensorflow_lite_support/cc/test:test_utils", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:cord", @@ -133,8 +133,8 @@ "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", "//tensorflow_lite_support/cc/test:test_utils", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", "@com_google_absl//absl/flags:flag", 
"@com_google_absl//absl/status", "@com_google_absl//absl/strings:cord", @@ -165,8 +165,8 @@ "//tensorflow_lite_support/cc/task/vision/proto:image_embedder_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", "//tensorflow_lite_support/cc/test:test_utils", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/status", "@org_tensorflow//tensorflow/lite:framework", @@ -174,3 +174,40 @@ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], ) + +cc_test_with_tflite( + name = "image_searcher_test", + srcs = ["image_searcher_test.cc"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_indices", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/vision:image_searcher", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:gtest_main", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core/proto:base_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:embedding_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:search_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:search_result_cc_proto", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:image_searcher_options_cc_proto", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + 
"//tensorflow_lite_support/cc/task/vision/utils:image_utils", + "//tensorflow_lite_support/cc/test:test_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_coral_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_coral_test.cc index 7e6a311a..887e90d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_coral_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_coral_test.cc
@@ -26,7 +26,7 @@ #include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" namespace tflite { namespace task {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc index c40836e..2daf293b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_classifier_test.cc
@@ -38,8 +38,8 @@ #include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" #include "tensorflow_lite_support/cc/test/test_utils.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" namespace tflite { namespace task {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc index 8877f28..41226f6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_embedder_test.cc
@@ -34,8 +34,8 @@ #include "tensorflow_lite_support/cc/task/vision/proto/image_embedder_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" #include "tensorflow_lite_support/cc/test/test_utils.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" namespace tflite { namespace task {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc new file mode 100644 index 0000000..00183eb --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_searcher_test.cc
@@ -0,0 +1,284 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/image_searcher.h" + +#include <memory> +#include <string> + +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/cord.h" // from @com_google_absl +#include "absl/strings/str_cat.h" // from @com_google_absl +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/gmock.h" +#include "tensorflow_lite_support/cc/port/gtest.h" +#include "tensorflow_lite_support/cc/port/status_matchers.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/search_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_searcher_options.pb.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" 
+#include "tensorflow_lite_support/cc/test/test_utils.h" + +namespace tflite { +namespace task { +namespace vision { +namespace { + +using ::testing::HasSubstr; +using ::testing::Optional; +using ::tflite::support::kTfLiteSupportPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::processor::NearestNeighbor; +using ::tflite::task::processor::SearchResult; + +constexpr char kTestDataDirectory[] = + "/tensorflow_lite_support/cc/test/testdata/task/" + "vision/"; +// Test embedder model. Float inputs, produces feature vectors that are not +// L2-normalized as this model doesn't include a L2_NORMALIZATION TFLite Op. +constexpr char kMobileNetV3Embedder[] = + "mobilenet_v3_small_100_224_embedder.tflite"; +// Standalone test index. +constexpr char kIndex[] = "searcher_index.ldb"; +// Test searcher model. Identical to kMobileNetV3Embedder, but with the contents +// of kIndex baked into the model metadata. +constexpr char kMobileNetV3Searcher[] = + "mobilenet_v3_small_100_224_searcher.tflite"; + +StatusOr<ImageData> LoadImage(std::string image_name) { + return DecodeImageFromFile( + JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name)); +} + +// Checks that the two provided `SearchResult` protos are equal, with a +// tolerancy on floating-point scores to account for numerical instabilities. 
+void ExpectApproximatelyEqual(const SearchResult& actual, + const SearchResult& expected) { + const float kPrecision = 1e-5; + EXPECT_EQ(actual.nearest_neighbors_size(), expected.nearest_neighbors_size()); + for (int i = 0; i < actual.nearest_neighbors_size(); ++i) { + const NearestNeighbor& a = actual.nearest_neighbors(i); + const NearestNeighbor& b = expected.nearest_neighbors(i); + EXPECT_EQ(a.metadata(), b.metadata()); + EXPECT_NEAR(a.distance(), b.distance(), kPrecision); + } +} + +class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + +TEST_F(CreateFromOptionsTest, SucceedsWithStandaloneIndex) { + ImageSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex)); + + SUPPORT_ASSERT_OK(ImageSearcher::CreateFromOptions(options)); +} + +TEST_F(CreateFromOptionsTest, SucceedsWithMetadataIndex) { + ImageSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Searcher)); + options.mutable_embedding_options()->set_l2_normalize(true); + + SUPPORT_ASSERT_OK(ImageSearcher::CreateFromOptions(options)); +} + +TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { + ImageSearcherOptions options; + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex)); + + StatusOr<std::unique_ptr<ImageSearcher>> image_searcher_or = + ImageSearcher::CreateFromOptions(options); + + EXPECT_EQ(image_searcher_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + image_searcher_or.status().message(), + 
HasSubstr("Missing mandatory `model_file` field in `base_options`")); + EXPECT_THAT(image_searcher_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithMissingIndex) { + ImageSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + + StatusOr<std::unique_ptr<ImageSearcher>> image_searcher_or = + ImageSearcher::CreateFromOptions(options); + + EXPECT_EQ(image_searcher_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + image_searcher_or.status().message(), + HasSubstr("Unable to find index file: SearchOptions.index_file is not " + "set and no AssociatedFile with type SCANN_INDEX_FILE could be " + "found in the output tensor metadata.")); + EXPECT_THAT(image_searcher_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord(absl::StrCat( + TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithQuantization) { + ImageSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_embedding_options()->set_quantize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex)); + + StatusOr<std::unique_ptr<ImageSearcher>> image_searcher_or = + ImageSearcher::CreateFromOptions(options); + + EXPECT_EQ(image_searcher_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_searcher_or.status().message(), + HasSubstr("Setting EmbeddingOptions.quantize = true is not " + "allowed in searchers")); + 
EXPECT_THAT(image_searcher_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { + ImageSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex)); + options.mutable_search_options()->set_max_results(-1); + + StatusOr<std::unique_ptr<ImageSearcher>> image_searcher_or = + ImageSearcher::CreateFromOptions(options); + + EXPECT_EQ(image_searcher_or.status().code(), + absl::StatusCode::kInvalidArgument); + EXPECT_THAT(image_searcher_or.status().message(), + HasSubstr("SearchOptions.max_results must be > 0, found -1")); + EXPECT_THAT(image_searcher_or.status().GetPayload(kTfLiteSupportPayload), + Optional(absl::Cord( + absl::StrCat(TfLiteSupportStatus::kInvalidArgumentError)))); +} + +TEST(SearchTest, SucceedsWithStandaloneIndex) { + // Create Searcher. + ImageSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex)); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher, + ImageSearcher::CreateFromOptions(options)); + // Load image. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + image.pixel_data, FrameBuffer::Dimension{image.width, image.height}); + + // Perform search. 
+ SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, + searcher->Search(*frame_buffer)); + ImageDataFree(&image); + + // Check results. + ExpectApproximatelyEqual( + result, ParseTextProtoOrDie<SearchResult>(R"pb( + nearest_neighbors { metadata: "burger" distance: 0.0 } + nearest_neighbors { metadata: "car" distance: 1.82244 } + nearest_neighbors { metadata: "bird" distance: 1.93094 } + nearest_neighbors { metadata: "dog" distance: 2.04736 } + nearest_neighbors { metadata: "cat" distance: 2.07587 } + )pb")); +} + +TEST(SearchTest, SucceedsWithMetadataIndex) { + // Create Searcher. + ImageSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Searcher)); + options.mutable_embedding_options()->set_l2_normalize(true); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher, + ImageSearcher::CreateFromOptions(options)); + // Load image. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + image.pixel_data, FrameBuffer::Dimension{image.width, image.height}); + + // Perform search. + SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, + searcher->Search(*frame_buffer)); + ImageDataFree(&image); + + // Check results. + ExpectApproximatelyEqual( + result, ParseTextProtoOrDie<SearchResult>(R"pb( + nearest_neighbors { metadata: "burger" distance: 0.0 } + nearest_neighbors { metadata: "car" distance: 1.82244 } + nearest_neighbors { metadata: "bird" distance: 1.93094 } + nearest_neighbors { metadata: "dog" distance: 2.04736 } + nearest_neighbors { metadata: "cat" distance: 2.07587 } + )pb")); +} + +TEST(SearchTest, SucceedsWithMaxResults) { + // Create Searcher. 
+ ImageSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name(JoinPath( + "./" /*test src dir*/, kTestDataDirectory, kMobileNetV3Embedder)); + options.mutable_embedding_options()->set_l2_normalize(true); + options.mutable_search_options()->mutable_index_file()->set_file_name( + JoinPath("./" /*test src dir*/, kTestDataDirectory, kIndex)); + options.mutable_search_options()->set_max_results(2); + SUPPORT_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSearcher> searcher, + ImageSearcher::CreateFromOptions(options)); + // Load image. + SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image, LoadImage("burger.jpg")); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + image.pixel_data, FrameBuffer::Dimension{image.width, image.height}); + + // Perform search. + SUPPORT_ASSERT_OK_AND_ASSIGN(const SearchResult& result, + searcher->Search(*frame_buffer)); + ImageDataFree(&image); + + // Check results. + ExpectApproximatelyEqual( + result, ParseTextProtoOrDie<SearchResult>(R"pb( + nearest_neighbors { metadata: "burger" distance: 0.0 } + nearest_neighbors { metadata: "car" distance: 1.82244 } + )pb")); +} + +} // namespace +} // namespace vision +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc index dc768a4..8671b68 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/image_segmenter_test.cc
@@ -37,9 +37,9 @@ #include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" #include "tensorflow_lite_support/cc/test/message_matchers.h" #include "tensorflow_lite_support/cc/test/test_utils.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" namespace tflite { namespace task {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc index 4a33e4b..6c0f395 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc
@@ -38,9 +38,9 @@ #include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" #include "tensorflow_lite_support/cc/test/message_matchers.h" #include "tensorflow_lite_support/cc/test/test_utils.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" namespace tflite {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/BUILD index 9b541f5..43dd738 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/BUILD
@@ -91,3 +91,19 @@ name = "30k-clean", extension = "model", ) + +filegroup( + name = "test_indices", + srcs = glob([ + "*.ldb", + ]), +) + +filegroup( + name = "test_searchers", + srcs = [ + "mobilebert_searcher.tflite", + "regex_searcher.tflite", + "universal_sentence_encoder_searcher.tflite", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/mobilebert_index.ldb b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/mobilebert_index.ldb new file mode 100644 index 0000000..fca3e7d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/mobilebert_index.ldb Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/regex_index.ldb b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/regex_index.ldb new file mode 100644 index 0000000..c1e853548 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/regex_index.ldb Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/regex_searcher.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/regex_searcher.tflite new file mode 100644 index 0000000..b82e6d8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/regex_searcher.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/universal_sentence_encoder_index.ldb b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/universal_sentence_encoder_index.ldb new file mode 100644 index 0000000..88ed034 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/universal_sentence_encoder_index.ldb Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/universal_sentence_encoder_searcher.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/universal_sentence_encoder_searcher.tflite new file mode 100644 index 0000000..877434a0 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text/universal_sentence_encoder_searcher.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/BUILD index 6de3411..1b83d3e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/BUILD
@@ -19,3 +19,10 @@ "*.png", ]), ) + +filegroup( + name = "test_indices", + srcs = glob([ + "*.ldb", + ]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v3_small_100_224_searcher.tflite b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v3_small_100_224_searcher.tflite new file mode 100644 index 0000000..47b5fdd --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/mobilenet_v3_small_100_224_searcher.tflite Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/searcher_index.ldb b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/searcher_index.ldb new file mode 100644 index 0000000..923f27a0 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/vision/searcher_index.ldb Binary files differ
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h index 7caf49e..f92f838 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/utils/jni_utils.h
@@ -54,24 +54,28 @@ return std::forward<T>(t); } -// Converts a std::vector<T> into a Java ArrayList using a converter, which -// processes a single element in the vector before adding it to the ArrayList. -template <typename T> -jobject ConvertVectorToArrayList(JNIEnv* env, - const std::vector<T>& results, - std::function<jobject(T)> converter) { +// Converts an iterable (specified by iterators, `begin` and `end`) into +// a Java ArrayList using a converter, which processes a single element in the +// iterable before adding it to the ArrayList. +template <typename Iterator> +jobject ConvertVectorToArrayList( + JNIEnv* env, + const Iterator& begin, + const Iterator& end, + std::function<jobject(typename std::iterator_traits<Iterator>::value_type)> + converter) { jclass array_list_class = env->FindClass("java/util/ArrayList"); jmethodID array_list_ctor = env->GetMethodID(array_list_class, "<init>", "(I)V"); - jint initial_capacity = static_cast<jint>(results.size()); + jint initial_capacity = static_cast<jint>(std::distance(begin, end)); jobject array_list_object = env->NewObject(array_list_class, array_list_ctor, initial_capacity); jmethodID array_list_add_method = env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z"); - for (const auto& ans : results) { + for (auto it = begin; it != end; ++it) { env->CallBooleanMethod(array_list_object, array_list_add_method, - converter(ans)); + converter(*it)); } return array_list_object; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py index 33a2231..08e6271 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py
@@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# Lint as: python3 """Tests for tensorflow_lite_support.custom_ops.ngrams.""" import os
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.py b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.py index b6a1a67..70de237 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.py
@@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# Lint as: python3 """Tests for tensorflow_lite_support.custom_ops.kernel.whitespace_tokenizer.""" import os
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py index 21efed56..341d387 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py
@@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== -# Lint as: python3 """Python class that implements Sentencepiece tokenizer. It follows TF.text designers design.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py index 3609b469..de5bc65 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py
@@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== -# Lint as: python3 """Tests for sentencepiece_tokenizer.""" import os
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/colab/on_device_text_to_image_search_tflite.ipynb b/third_party/tflite_support/src/tensorflow_lite_support/examples/colab/on_device_text_to_image_search_tflite.ipynb new file mode 100644 index 0000000..dc7dd25d --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/colab/on_device_text_to_image_search_tflite.ipynb
@@ -0,0 +1,1263 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "k91nHOvT6I11" + }, + "source": [ + "##### Copyright 2022 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "v1_OyHN36JyC" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "C5PO5cHLFDT1" + }, + "source": [ + "# On-device Text-to-Image Search with TensorFlow Lite Searcher Library\n", + "\n", + "In this colab, we showcase an end to end example of how to train an image-text dual encoder model and how to perform retrieval with TFLite Searcher Library. We are going to use the [COCO 2014](https://cocodataset.org/#home) dataset, and in the end you'll be able to retrieve images using a text description.\n", + "\n", + "First, we need to encode the images into high-dimensional vectors. Then we index them with [Model Maker Searcher API](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/searcher/). 
During inference, a TFLite text embedder encodes the text query into another high-dimensional vector in the same embedding space, and invokes the [on-device ScaNN searcher](https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/scann_ondevice) to retrieve similar images.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2Tc6uMrczn4g" + }, + "source": [ + "You can download the pre-trained searcher model packed with ScaNN index from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/searcher/text_to_image_blogpost/searcher_model.tflite) and skip to [inference](#scrollTo=EeZwqEnxW5Xl). Be sure to name it `searcher_model.tflite` and upload it to colab under the current working directory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8KBOd6FudlqM" + }, + "outputs": [], + "source": [ + "!pip install -q -U tensorflow tensorflow-hub tensorflow-addons\n", + "!pip install -q -U tflite-support\n", + "!pip install -q -U tflite-model-maker\n", + "!pip install -q -U tensorflow-text==2.10.0b2\n", + "!sudo apt-get -qq install libportaudio2 # Needed by tflite-support" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SXYcCLchJXil" + }, + "source": [ + "Note you might need to restart the runtime after installation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CDDQxg1RpGPZ" + }, + "outputs": [], + "source": [ + "import json\n", + "import math\n", + "import os\n", + "import pickle\n", + "import random\n", + "import shutil\n", + "import matplotlib.pyplot as plt\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "import tensorflow.compat.v1 as tf1\n", + "from tensorflow.keras import layers\n", + "import tensorflow_addons as tfa\n", + "import tensorflow_hub as hub\n", + "import tensorflow_text as text\n", + "from tensorflow_text.python.ops import fast_sentencepiece_tokenizer as sentencepiece_tokenizer\n", + "\n", + "# Suppressing tf.hub warnings\n", + "tf.get_logger().setLevel('ERROR')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6yNYV-uBHtRY" + }, + "outputs": [], + "source": [ + "DATASET_DIR = 'datasets'\n", + "CAPTION_URL = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip'\n", + "TRAIN_IMAGE_URL = 'http://images.cocodataset.org/zips/train2014.zip'\n", + "VALID_IMAGE_URL = 'http://images.cocodataset.org/zips/val2014.zip'\n", + "TRAIN_IMAGE_DIR = os.path.join(DATASET_DIR, 'train2014')\n", + "VALID_IMAGE_DIR = os.path.join(DATASET_DIR, 'val2014')\n", + "TRAIN_IMAGE_PREFIX = 'COCO_train2014_'\n", + "VALID_IMAGE_PREFIX = 'COCO_val2014_'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "niIHzEnzJ8HR" + }, + "outputs": [], + "source": [ + "IMAGE_SIZE = (384, 384)\n", + "EFFICIENT_NET_URL = 'https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_s/feature_vector/2'\n", + "UNIVERSAL_SENTENCE_ENCODER_URL = 'https://tfhub.dev/google/universal-sentence-encoder-lite/2'\n", + "\n", + "BATCH_SIZE = 256\n", + "NUM_EPOCHS = 10\n", + "SEQ_LENGTH = 128\n", + "EMB_SIZE = 128" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tPvDrjQ9FBNw" + }, + "source": [ + "## Get COCO dataset\n", + "\n", + "We are 
not using Tensorflow Dataset to get the [coco_captions](https://www.tensorflow.org/datasets/catalog/coco_captions) dataset due to disk space concerns. The following code will download and process the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "FrQXA95HGzYN" + }, + "outputs": [], + "source": [ + "#@title Functions for downloading and parsing annotations.\n", + "\n", + "def parse_annotation_json(json_path):\n", + " # Assuming the json file is already downloaded.\n", + " with open(json_path, 'r') as f:\n", + " json_obj = json.load(f)\n", + "\n", + " # Parsing out the following information from the annotation json: the COCO\n", + " # image id and their corresponding flickr post id, as well as the captions.\n", + " mapping = dict()\n", + " for caption in json_obj['annotations']:\n", + " image_id = caption['image_id']\n", + " if image_id not in mapping:\n", + " mapping[image_id] = [[]]\n", + " mapping[image_id][0].append(caption['caption'])\n", + " for image in json_obj['images']:\n", + " # The flickr url here is the CDN url. 
We need to split it to get the post\n", + " # id.\n", + " flickr_url = image['flickr_url']\n", + " url_parts = flickr_url.split('/')\n", + " flickr_id = url_parts[-1].split('_')[0]\n", + " mapping[image['id']].append(flickr_id)\n", + " return list(mapping.items())\n", + "\n", + "\n", + "def get_train_valid_captions():\n", + " # Parse and cache the annotation for train and valid\n", + " train_pickle_path = os.path.join(DATASET_DIR, 'train_captions.pickle')\n", + " valid_pickle_path = os.path.join(DATASET_DIR, 'valid_captions.pickle')\n", + "\n", + " if not os.path.exists(train_pickle_path) or not os.path.exists(\n", + " valid_pickle_path):\n", + " # Parse and cache the annotations if they don't exist\n", + " annotation_zip = tf.keras.utils.get_file(\n", + " 'annotations.zip',\n", + " cache_dir=os.path.abspath('.'),\n", + " cache_subdir=os.path.join(DATASET_DIR, 'tmp'),\n", + " origin=CAPTION_URL,\n", + " extract=True,\n", + " )\n", + " os.remove(annotation_zip)\n", + " train_img_cap = parse_annotation_json(\n", + " os.path.join(DATASET_DIR, 'tmp', 'annotations',\n", + " 'captions_train2014.json'))\n", + " valid_img_cap = parse_annotation_json(\n", + " os.path.join(DATASET_DIR, 'tmp', 'annotations',\n", + " 'captions_val2014.json'))\n", + " with open(train_pickle_path, 'wb') as f:\n", + " pickle.dump(train_img_cap, f)\n", + " with open(valid_pickle_path, 'wb') as f:\n", + " pickle.dump(valid_img_cap, f)\n", + " shutil.rmtree(os.path.join(DATASET_DIR, 'tmp'))\n", + " else:\n", + " # Load the cached annotations\n", + " with open(train_pickle_path, 'rb') as f:\n", + " train_img_cap = pickle.load(f)\n", + " with open(valid_pickle_path, 'rb') as f:\n", + " valid_img_cap = pickle.load(f)\n", + " return train_img_cap, valid_img_cap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "OTOPBL57a74w" + }, + "outputs": [], + "source": [ + "#@title Functions for downloading the images and create the dataset.\n", + "\n", + 
"def get_sentencepiece_tokenizer_in_tf2():\n", + " # The universal sentence encoder model from TFHub is in TF1 Module format. We\n", + " # need to directly access the asset_paths to get the sentencepiece tokenizer\n", + " # proto path.\n", + " module = hub.load(UNIVERSAL_SENTENCE_ENCODER_URL)\n", + " spm_path = module.asset_paths[0].asset_path.numpy()\n", + " with tf.io.gfile.GFile(spm_path, mode='rb') as f:\n", + " return sentencepiece_tokenizer.FastSentencepieceTokenizer(f.read())\n", + "\n", + "\n", + "def prepare_dataset(id_image_info_list,\n", + " image_file_prefix,\n", + " image_dir,\n", + " image_zip_url,\n", + " shuffle=False):\n", + " # Download and unzip the dataset if it's not there already.\n", + " if not os.path.exists(image_dir):\n", + " image_zip = tf.keras.utils.get_file(\n", + " 'image.zip',\n", + " cache_dir=os.path.abspath('.'),\n", + " cache_subdir=os.path.join(DATASET_DIR),\n", + " origin=image_zip_url,\n", + " extract=True,\n", + " )\n", + " os.remove(image_zip)\n", + "\n", + " # Convert the lists into tensors so that we can index into it in the dataset\n", + " # transformations later.\n", + " coco_ids, image_info = zip(*id_image_info_list)\n", + " captions, flickr_ids = zip(*image_info)\n", + " file_names = list(\n", + " map(\n", + " lambda id: os.path.join(image_dir, '%s%012d.jpg' %\n", + " (image_file_prefix, id)), coco_ids))\n", + " coco_ids_tensor = tf.constant(coco_ids)\n", + " captions_tensor = tf.ragged.constant(captions)\n", + " file_names_tensor = tf.constant(file_names)\n", + " flickr_ids_tensor = tf.constant(flickr_ids)\n", + "\n", + " # The initial dataset only contains the index. 
This is to make sure the\n", + " # dataset has a known size.\n", + " dataset = tf.data.Dataset.range(len(coco_ids))\n", + "\n", + " sp = get_sentencepiece_tokenizer_in_tf2()\n", + "\n", + " def _load_image_and_select_caption(i):\n", + " image_id = coco_ids_tensor[i]\n", + " captions = captions_tensor[i]\n", + " image_path = file_names_tensor[i]\n", + " flickr_id = flickr_ids_tensor[i]\n", + " image = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)\n", + "\n", + " # Randomly select one caption from the many captions we have for each image\n", + " caption_idx = tf.random.uniform((1,),\n", + " minval=0,\n", + " maxval=tf.shape(captions)[0],\n", + " dtype=tf.int32)[0]\n", + " caption = captions[caption_idx]\n", + " caption = tf.sparse.from_dense(sp.tokenize(caption))\n", + " example = {\n", + " 'image': image,\n", + " 'image_id': image_id,\n", + " 'caption': caption,\n", + " 'flickr_id': flickr_id\n", + " }\n", + " return example\n", + "\n", + " def _resize_image(example):\n", + " # Efficient net requires the pixels to be in range of [0, 1].\n", + " example['image'] = tf.image.resize(example['image'], size=IMAGE_SIZE) / 255\n", + " return example\n", + "\n", + " dataset = (\n", + " # Load the images from disk and decode them into numpy arrays.\n", + " dataset.map(\n", + " _load_image_and_select_caption,\n", + " num_parallel_calls=tf.data.AUTOTUNE,\n", + " deterministic=not shuffle)\n", + "\n", + " # Resizing image is slow. We put the stage into a separate map so that it\n", + " # could get more threads to not be the bottleneck.\n", + " .map(\n", + " _resize_image,\n", + " num_parallel_calls=tf.data.AUTOTUNE,\n", + " deterministic=not shuffle))\n", + "\n", + " if shuffle:\n", + " dataset = dataset.shuffle(BATCH_SIZE * 10)\n", + "\n", + " dataset = dataset.batch(BATCH_SIZE)\n", + " return dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Kzpigw9ozZOM" + }, + "source": [ + "Download the datasets and preprocess them." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 16948, + "status": "ok", + "timestamp": 1651885239693, + "user": { + "displayName": "Zonglin Li", + "userId": "11843710831668693042" + }, + "user_tz": 240 + }, + "id": "pHbgdBfFWmtz", + "outputId": "38b5e03b-4c19-430f-af0a-48019383540e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading data from http://images.cocodataset.org/annotations/annotations_trainval2014.zip\n", + "252878848/252872794 [==============================] - 8s 0us/step\n", + "252887040/252872794 [==============================] - 8s 0us/step\n", + "Train number of images: 82783\n", + "Valid number of images: 40504\n", + "COCO image id: 318556\n", + "Captions: ['A very clean and well decorated empty bathroom', 'A blue and white bathroom with butterfly themed wall tiles.', 'A bathroom with a border of butterflies and blue paint on the walls above it.', 'An angled view of a beautifully decorated bathroom.', 'A clock that blends in with the wall hangs in a bathroom. 
']\n", + "Flickr post url: http://flickr.com/photo.gne?id=3378902101\n" + ] + } + ], + "source": [ + "# We parse the caption json files first.\n", + "train_img_cap, valid_img_cap = get_train_valid_captions()\n", + "print(f'Train number of images: {len(train_img_cap)}')\n", + "print(f'Valid number of images: {len(valid_img_cap)}')\n", + "\n", + "example = train_img_cap[0]\n", + "print(f'COCO image id: {example[0]}')\n", + "print(f'Captions: {example[1][0]}')\n", + "print(f'Flickr post url: http://flickr.com/photo.gne?id={example[1][1]}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 775219, + "status": "ok", + "timestamp": 1651886014906, + "user": { + "displayName": "Zonglin Li", + "userId": "11843710831668693042" + }, + "user_tz": 240 + }, + "id": "Ke6EeKAqj1vB", + "outputId": "4c550552-270b-435a-e8d7-a73c92da9ef9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading data from http://images.cocodataset.org/zips/val2014.zip\n", + "6645014528/6645013297 [==============================] - 183s 0us/step\n", + "6645022720/6645013297 [==============================] - 183s 0us/step\n", + "Downloading data from http://images.cocodataset.org/zips/train2014.zip\n", + "13510574080/13510573713 [==============================] - 412s 0us/step\n", + "13510582272/13510573713 [==============================] - 412s 0us/step\n" + ] + } + ], + "source": [ + "# Shuffle both the train and validation sets\n", + "random.shuffle(valid_img_cap)\n", + "random.shuffle(train_img_cap)\n", + "\n", + "# We randomly sample 5000 image-caption pairs from validation set for validation\n", + "# during training, to match the setup of\n", + "# https://www.tensorflow.org/datasets/catalog/coco_captions. 
However, when\n", + "# generating the retrieval database later on, we will use all the images in both\n", + "# validation and training splits.\n", + "valid_dataset = prepare_dataset(\n", + " valid_img_cap[:5000],\n", + " VALID_IMAGE_PREFIX,\n", + " VALID_IMAGE_DIR,\n", + " VALID_IMAGE_URL)\n", + "train_dataset = prepare_dataset(\n", + " train_img_cap,\n", + " TRAIN_IMAGE_PREFIX,\n", + " TRAIN_IMAGE_DIR,\n", + " TRAIN_IMAGE_URL,\n", + " shuffle=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g11BA6ycJAru" + }, + "source": [ + "## Define models" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jj2aulT90vgO" + }, + "source": [ + "The image encoder and text encoder may not output the embeddings with the same amount of dimensions. We need to project them into the same embedding space" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "k6tSQPkQBfht" + }, + "outputs": [], + "source": [ + "def project_embeddings(embeddings, num_projection_layers, projection_dims,\n", + " dropout_rate):\n", + "\n", + " projected_embeddings = layers.Dense(units=projection_dims)(embeddings)\n", + " for _ in range(num_projection_layers):\n", + " x = tf.nn.relu(projected_embeddings)\n", + " x = layers.Dense(projection_dims)(x)\n", + " x = layers.Dropout(dropout_rate)(x)\n", + " x = layers.Add()([projected_embeddings, x])\n", + " projected_embeddings = layers.LayerNormalization()(x)\n", + "\n", + " # Finally we L2 normalize the embeddings. 
In general, L2 normalized embeddings\n", + " # are easier to retrieve.\n", + " projected_embeddings = tf.math.l2_normalize(projected_embeddings, axis=1)\n", + " return projected_embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "U64G7g3pq5bH" + }, + "outputs": [], + "source": [ + "def create_image_encoder(num_projection_layers,\n", + " projection_dims,\n", + " dropout_rate,\n", + " trainable=False):\n", + " efficient_net = hub.KerasLayer(EFFICIENT_NET_URL, trainable=trainable)\n", + " inputs = layers.Input(shape=IMAGE_SIZE + (3,), name='image_input')\n", + " embeddings = efficient_net(inputs)\n", + " outputs = project_embeddings(embeddings, num_projection_layers,\n", + " projection_dims, dropout_rate)\n", + " return keras.Model(inputs, outputs, name='image_encoder')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ObnLD9KM0uy3" + }, + "source": [ + "We use [Universal Sentence Encoder](https://tfhub.dev/google/universal-sentence-encoder-lite/2), a SOTA sentence embedding model, as the text encoder base model. The TFHub lite version is a TF1 saved model. To make it work well in TF2 and later TFLite conversion, we create two models, one is the frozen universal sentence encoder, and the other is the trainable projection layer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eJ0MLxdKcFb1" + }, + "outputs": [], + "source": [ + "def create_text_encoder():\n", + " encoder = hub.KerasLayer(\n", + " UNIVERSAL_SENTENCE_ENCODER_URL,\n", + " name='universal_sentence_encoder',\n", + " signature='default')\n", + " encoder.trainable = False\n", + " inputs = layers.Input(\n", + " shape=(None,), dtype=tf.int64, name='text_input', sparse=True)\n", + " embeddings = encoder(\n", + " dict(\n", + " values=inputs.values,\n", + " indices=inputs.indices,\n", + " dense_shape=inputs.dense_shape))\n", + " return keras.Model(inputs, embeddings, name='text_encoder')\n", + "\n", + "\n", + "def create_text_embedder_projection(input_dim, num_projection_layers,\n", + " projection_dims, dropout_rate):\n", + " inputs = layers.Input(shape=(input_dim), dtype=tf.float32, name='text_input')\n", + " outputs = project_embeddings(inputs, num_projection_layers, projection_dims,\n", + " dropout_rate)\n", + " return keras.Model(inputs, outputs, name='projection_layers')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yHX9RYZ62ZmC" + }, + "source": [ + "This dual encoder model is derived from this [Keras post](https://keras.io/examples/nlp/nl_image_search/)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2jVV2cHKCWIm" + }, + "outputs": [], + "source": [ + "class DualEncoder(keras.Model):\n", + "\n", + " def __init__(self,\n", + " text_encoder,\n", + " text_encoder_projection,\n", + " image_encoder,\n", + " temperature,\n", + " **kwargs):\n", + " super(DualEncoder, self).__init__(**kwargs)\n", + " self.text_encoder = text_encoder\n", + " self.text_encoder_projection = text_encoder_projection\n", + " self.image_encoder = image_encoder\n", + "\n", + " # Temperature controls the contrast of softmax output. 
In general, a low\n", + " # temperature increases the contrast and a high temperature decreases it.\n", + " self.temperature = temperature\n", + " self.loss_tracker = keras.metrics.Mean(name='loss')\n", + "\n", + " @property\n", + " def metrics(self):\n", + " return [self.loss_tracker]\n", + "\n", + " def call(self, features, training=False):\n", + " # If there are two GPUs present, we use one of them for image encoder and\n", + " # one for text encoder. If there's only one GPU then they will be trained on\n", + " # the same GPU.\n", + " with tf.device('/gpu:0'):\n", + " caption_embeddings = self.text_encoder(\n", + " features['caption'], training=False)\n", + " caption_embeddings = self.text_encoder_projection(\n", + " caption_embeddings, training=training)\n", + " with tf.device('/gpu:1'):\n", + " image_embeddings = self.image_encoder(\n", + " features['image'], training=training)\n", + " return caption_embeddings, image_embeddings\n", + "\n", + " def compute_loss(self, caption_embeddings, image_embeddings):\n", + " # Computing the loss with dot product similarity between image and text\n", + " # embeddings.\n", + " logits = (\n", + " tf.matmul(caption_embeddings, image_embeddings, transpose_b=True) /\n", + " self.temperature)\n", + " images_similarity = tf.matmul(\n", + " image_embeddings, image_embeddings, transpose_b=True)\n", + " captions_similarity = tf.matmul(\n", + " caption_embeddings, caption_embeddings, transpose_b=True)\n", + "\n", + " # The targets are a softmax over the mean self-similarity of the captions\n", + " # and images. This is more lenient to similar examples in the same batch.\n", + " targets = keras.activations.softmax(\n", + " (captions_similarity + images_similarity) / (2 * self.temperature))\n", + " captions_loss = keras.losses.categorical_crossentropy(\n", + " y_true=targets, y_pred=logits, from_logits=True)\n", + " images_loss = keras.losses.categorical_crossentropy(\n", + " y_true=tf.transpose(targets),\n", + "
y_pred=tf.transpose(logits),\n", + " from_logits=True)\n", + " return (captions_loss + images_loss) / 2\n", + "\n", + " def train_step(self, features):\n", + " with tf.GradientTape() as tape:\n", + " # Forward pass\n", + " caption_embeddings, image_embeddings = self(features, training=True)\n", + " loss = self.compute_loss(caption_embeddings, image_embeddings)\n", + "\n", + " # Backward pass\n", + " gradients = tape.gradient(loss, self.trainable_variables)\n", + " self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))\n", + " self.loss_tracker.update_state(loss)\n", + " return {'loss': self.loss_tracker.result()}\n", + "\n", + " def test_step(self, features):\n", + " caption_embeddings, image_embeddings = self(features, training=False)\n", + " loss = self.compute_loss(caption_embeddings, image_embeddings)\n", + " self.loss_tracker.update_state(loss)\n", + " return {'loss': self.loss_tracker.result()}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9POw0Ye4x-XR" + }, + "source": [ + "## Train the Dual Encoder model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y9Wz75GfxN6L" + }, + "source": [ + "Load the models from Tensorflow Hub." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JHsZiXEK_ZwO" + }, + "outputs": [], + "source": [ + "# The text embedder consists of two models. One is the frozen base universal\n", + "# sentence encoder, and the other is the trainable projection layer. 
We are\n", + "# doing this instead of one model to make later TFLite model conversion easier.\n", + "text_encoder = create_text_encoder()\n", + "projection_layers = create_text_embedder_projection(\n", + " input_dim=512, # Universal sentence encoder output has 512 dimensions\n", + " num_projection_layers=1,\n", + " projection_dims=EMB_SIZE,\n", + " dropout_rate=0.1)\n", + "\n", + "image_encoder = create_image_encoder(\n", + " num_projection_layers=1, projection_dims=EMB_SIZE, dropout_rate=0.1)\n", + "\n", + "dual_encoder = DualEncoder(\n", + " text_encoder, projection_layers, image_encoder, temperature=0.05)\n", + "dual_encoder.compile(\n", + " optimizer=tfa.optimizers.AdamW(learning_rate=0.001, weight_decay=0.001))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Tj8v8wq6xUbS" + }, + "source": [ + "Train the dual encoder model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 12338131, + "status": "ok", + "timestamp": 1651898372226, + "user": { + "displayName": "Zonglin Li", + "userId": "11843710831668693042" + }, + "user_tz": 240 + }, + "id": "Q1a4h5DNCaBq", + "outputId": "62b2a90b-fcf3-4e5f-fff6-63ed1047b57a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n", + "324/324 [==============================] - 1387s 4s/step - loss: 1.8785 - val_loss: 1.5041 - lr: 0.0010\n", + "Epoch 2/10\n", + "324/324 [==============================] - 1345s 4s/step - loss: 1.4041 - val_loss: 1.3767 - lr: 0.0010\n", + "Epoch 3/10\n", + "324/324 [==============================] - 1351s 4s/step - loss: 1.3275 - val_loss: 1.3518 - lr: 0.0010\n", + "Epoch 4/10\n", + "324/324 [==============================] - 1364s 4s/step - loss: 1.2792 - val_loss: 1.3365 - lr: 9.0484e-04\n", + "Epoch 5/10\n", + "324/324 [==============================] - 1353s 4s/step - loss: 1.2511 - val_loss: 1.3124 - lr: 
8.1873e-04\n", + "Epoch 6/10\n", + "324/324 [==============================] - 1352s 4s/step - loss: 1.2366 - val_loss: 1.2991 - lr: 7.4082e-04\n", + "Epoch 7/10\n", + "324/324 [==============================] - 1359s 4s/step - loss: 1.2266 - val_loss: 1.2935 - lr: 6.7032e-04\n", + "Epoch 8/10\n", + "324/324 [==============================] - 1354s 4s/step - loss: 1.2154 - val_loss: 1.3117 - lr: 6.0653e-04\n", + "Epoch 9/10\n", + "324/324 [==============================] - 1359s 4s/step - loss: 1.2220 - val_loss: 1.3212 - lr: 5.4881e-04\n", + "Training completed. Saving image and text encoders.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 622). These functions will not be directly callable after loading.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Models are saved.\n" + ] + } + ], + "source": [ + "# We train the first three epochs with the learning rate of 0.001 and\n", + "# decrease it exponentially later on.\n", + "def lr_scheduler(epoch, lr):\n", + " if epoch \u003c 3:\n", + " return lr\n", + " else:\n", + " return max(lr * tf.math.exp(-0.1), lr * 0.1)\n", + "\n", + "# In colab, training takes roughly 4s per step, around 24 mins per epoch\n", + "early_stopping = tf.keras.callbacks.EarlyStopping(\n", + " monitor='val_loss', patience=2, restore_best_weights=True)\n", + "history = dual_encoder.fit(\n", + " train_dataset,\n", + " epochs=NUM_EPOCHS,\n", + " validation_data=valid_dataset,\n", + " callbacks=[\n", + " tf.keras.callbacks.LearningRateScheduler(lr_scheduler), early_stopping\n", + " ],\n", + " max_queue_size=2,\n", + ")\n", + "\n", + "# Save the models. 
We are not going to save the text_encoder since it's frozen\n", + "# and the TF2 saved model for text_encoder has problems converting to TFLite.\n", + "print('Training completed. Saving image and text encoders.')\n", + "dual_encoder.image_encoder.save('image_encoder')\n", + "dual_encoder.text_encoder_projection.save('text_encoder_projection')\n", + "print('Models are saved.')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hcyej5mYxX39" + }, + "source": [ + "## Create the text-to-image Searcher model using Model Maker" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Bp0qBKkyu4jA" + }, + "source": [ + "### Generate image embeddings" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dXdecbiY2NSs" + }, + "source": [ + "Load the validation and training datasets one more time. This time we do not shuffle the training split, and we use the whole validation split. Since images are not loaded until they are iterated, creating the datasets should be cheap." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PA_X283yMUsR" + }, + "outputs": [], + "source": [ + "combined_valid_dataset = prepare_dataset(\n", + " valid_img_cap,\n", + " VALID_IMAGE_PREFIX,\n", + " VALID_IMAGE_DIR,\n", + " VALID_IMAGE_URL)\n", + "deterministic_train_dataset = prepare_dataset(\n", + " train_img_cap,\n", + " TRAIN_IMAGE_PREFIX,\n", + " TRAIN_IMAGE_DIR,\n", + " TRAIN_IMAGE_URL)\n", + "\n", + "all_combined = deterministic_train_dataset.concatenate(combined_valid_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lK8FdFcx2siR" + }, + "source": [ + "Create the metadata (image file names and the Flickr post ID) from the dataset. This will later be packed into the TFLite model."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "M98I3IaHRBEl" + }, + "outputs": [], + "source": [ + "def create_metadata(image_file_prefix, image_dir):\n", + "\n", + " def _create_metadata(image_info):\n", + " # This is the same way we generated the image paths in the prepare_dataset\n", + " # function above\n", + " coco_id = image_info[0]\n", + " flickr_id = image_info[1][1]\n", + " return ('%s_%s' %\n", + " (flickr_id,\n", + " os.path.join(image_dir, '%s%012d.jpg' %\n", + " (image_file_prefix, coco_id)))).encode('utf-8')\n", + "\n", + " return _create_metadata\n", + "\n", + "\n", + "# We don't store the images in the index file, as that would be too big. We only\n", + "# store the image path and the corresponding Flickr id.\n", + "metadata = list(\n", + " map(create_metadata(TRAIN_IMAGE_PREFIX, TRAIN_IMAGE_DIR), train_img_cap))\n", + "metadata.extend(\n", + " map(create_metadata(VALID_IMAGE_PREFIX, VALID_IMAGE_DIR), valid_img_cap))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EYUk_D2S24Fg" + }, + "source": [ + "Generate the embeddings for all the images we have. We do it in Tensorflow with GPU instead of Model Maker. Again, these will be packed into the TFLite model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 1147631, + "status": "ok", + "timestamp": 1651899528619, + "user": { + "displayName": "Zonglin Li", + "userId": "11843710831668693042" + }, + "user_tz": 240 + }, + "id": "Vk--b8EgQhHo", + "outputId": "e4ea27c3-175e-43bd-d40e-a2668a8c9298" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "483/483 [==============================] - 1147s 2s/step\n", + "Embedding matrix shape: (123287, 128)\n" + ] + } + ], + "source": [ + "# Image encoder takes one input named `image_input` so we remove other values in\n", + "# the dataset.\n", + "image_dataset = all_combined.map(\n", + " lambda example: {'image_input': example['image']})\n", + "image_embeddings = dual_encoder.image_encoder.predict(image_dataset, verbose=1)\n", + "print(f'Embedding matrix shape: {image_embeddings.shape}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6Dzye66Xc8vE" + }, + "source": [ + "### Convert text embedder to TFLite" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IBef6gzm3AIQ" + }, + "source": [ + "We need to convert the saved model to TF1 as the base Universal Sentence Encoder is a TF1 model. 
It'll create a saved model dir on disk called `converted_model`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 10521, + "status": "ok", + "timestamp": 1651899539127, + "user": { + "displayName": "Zonglin Li", + "userId": "11843710831668693042" + }, + "user_tz": 240 + }, + "id": "jJV-44C0c_FK", + "outputId": "ef8ac6ee-7d65-470d-b6de-897df5a466af" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model saved to converted_model/\n" + ] + } + ], + "source": [ + "#@title Prepare the saved model\n", + "!rm -rf converted_model\n", + "\n", + "# This create a new TF1 SavedModel from 1). The tfhub USE, and 2). The\n", + "# projection layers trained and saved from TF2.\n", + "with tf1.Graph().as_default() as g:\n", + " with tf1.Session() as sess:\n", + " # Reload the Universal Sentence Encoder model from tfhub. We can't just save\n", + " # the USE in TF2 as we did for the projection layers as that causes issues\n", + " # in the TFLite converter.\n", + " module = hub.Module(UNIVERSAL_SENTENCE_ENCODER_URL)\n", + " spm_path = sess.run(module(signature='spm_path'))\n", + " with tf1.io.gfile.GFile(spm_path, mode='rb') as f:\n", + " serialized_spm = f.read()\n", + " spm_path = sess.run(module(signature='spm_path'))\n", + " input_str = tf1.placeholder(dtype=tf1.string, shape=[None])\n", + " tokenizer = sentencepiece_tokenizer.FastSentencepieceTokenizer(\n", + " model=serialized_spm)\n", + " tokenized = tf1.sparse.from_dense(tokenizer.tokenize(input_str).to_tensor())\n", + " tokenized = tf1.cast(tokenized, dtype=tf1.int64)\n", + " encodings = module(\n", + " inputs=dict(\n", + " values=tokenized.values,\n", + " indices=tokenized.indices,\n", + " dense_shape=tokenized.dense_shape))\n", + "\n", + " # Then combine it with the trained projection layers\n", + " projection_layers = 
tf1.keras.models.load_model('text_encoder_projection')\n", + " encodings = projection_layers(encodings)\n", + "\n", + " sess.run([tf1.global_variables_initializer(), tf1.tables_initializer()])\n", + "\n", + " # Save with SavedModelBuilder\n", + " builder = tf1.saved_model.Builder('converted_model')\n", + " sig_def = tf1.saved_model.predict_signature_def(\n", + " inputs={'input': input_str}, outputs={'output': encodings})\n", + " builder.add_meta_graph_and_variables(\n", + " sess,\n", + " tags=['serve'],\n", + " signature_def_map={\n", + " tf1.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def\n", + " },\n", + " clear_devices=True)\n", + " builder.save()\n", + "print('Model saved to converted_model/')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XeS_H13j3KY_" + }, + "source": [ + "Convert and save the TFLite model. Here the model only has the text encoder. We will add in the retrieval index in the following steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DPGs2kxbdGtK" + }, + "outputs": [], + "source": [ + "converter = tf.lite.TFLiteConverter.from_saved_model('converted_model')\n", + "converter.experimental_new_converter = True\n", + "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]\n", + "converter.allow_custom_ops = True\n", + "converted_model_tflite = converter.convert()\n", + "with open('text_embedder.tflite', 'wb') as f:\n", + " f.write(converted_model_tflite)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7gqnxcsDCTeq" + }, + "source": [ + "### Create TFLite Searcher model\n", + "\n", + "In general see the documentation of [`ScaNNOptions`](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/searcher/ScaNNOptions) for how to configure the searcher for your dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "z6peW6vvxMnF" + }, + "outputs": [], + "source": [ + "import tflite_model_maker as mm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bi6bkdnNVAiB" + }, + "outputs": [], + "source": [ + "scann_options = mm.searcher.ScaNNOptions(\n", + " # We use the dot product similarity as this is how the model is trained\n", + " distance_measure='dot_product',\n", + " # Enable space partitioning with K-Means tree\n", + " tree=mm.searcher.Tree(\n", + " # How many partitions to have. A rule of thumb is the square root of the\n", + " # dataset size. In this case it's 351.\n", + " num_leaves=int(math.sqrt(len(metadata))),\n", + " # Searching 4 partitions seems to give reasonable result. Searching more\n", + " # will definitely return better results, but it's more costly to run.\n", + " num_leaves_to_search=4),\n", + " # Compress each float to int8 in the embedding. See\n", + " # https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/searcher/ScoreAH\n", + " # for details\n", + " score_ah=mm.searcher.ScoreAH(\n", + " # Using 1 dimension per quantization block.\n", + " 1,\n", + " # Generally 0.2 works pretty well.\n", + " anisotropic_quantization_threshold=0.2))\n", + "\n", + "data = mm.searcher.DataLoader(\n", + " embedder_path='text_embedder.tflite',\n", + " dataset=image_embeddings,\n", + " metadata=metadata)\n", + "\n", + "model = mm.searcher.Searcher.create_from_data(\n", + " data=data, scann_options=scann_options)\n", + "model.export(\n", + " export_filename='searcher_model.tflite',\n", + " userinfo='',\n", + " export_format=mm.searcher.ExportFormat.TFLITE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EeZwqEnxW5Xl" + }, + "source": [ + "## Run inference using Task Library" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "z-gkyy7vXRS0" + }, + "outputs": [], + "source": [ + "from 
tflite_support.task import text\n", + "from tflite_support.task import core" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZQXqwY_X3eP4" + }, + "source": [ + "Configure the searcher to return 6 results per query and not to L2 normalize the query embeddings because the text encoder has already normalized them. See [source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/python/task/text/text_searcher.py) on how to configure the `TextSearcher`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nJlkXpDsW8_5" + }, + "outputs": [], + "source": [ + "options = text.TextSearcherOptions(\n", + " base_options=core.BaseOptions(\n", + " file_name='searcher_model.tflite'))\n", + "\n", + "# The searcher returns 6 results\n", + "options.search_options.max_results = 6\n", + "\n", + "tflite_searcher = text.TextSearcher.create_from_options(options)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ggwCRyT_kQGs" + }, + "outputs": [], + "source": [ + "def search_image_with_text(query_str, show_images=False):\n", + " neighbors = tflite_searcher.search(query_str)\n", + "\n", + " for i, neighbor in enumerate(neighbors.nearest_neighbors):\n", + " metadata = neighbor.metadata.decode('utf-8').split('_')\n", + " flickr_id = metadata[0]\n", + " print('Flickr url for %d: http://flickr.com/photo.gne?id=%s' %\n", + " (i + 1, flickr_id))\n", + "\n", + " if show_images:\n", + " plt.figure(figsize=(20, 13))\n", + " for i, neighbor in enumerate(neighbors.nearest_neighbors):\n", + " ax = plt.subplot(2, 3, i + 1)\n", + "\n", + " # Using negative distance since on-device ScaNN returns negative\n", + " # dot-product distance.\n", + " ax.set_title('%d: Similarity: %.05f' % (i + 1, -neighbor.distance))\n", + " metadata = neighbor.metadata.decode('utf-8').split('_')\n", + " image_path = '_'.join(metadata[1:])\n", + " image = tf.image.decode_jpeg(\n", + " 
tf.io.read_file(image_path), channels=3) / 255\n", + " plt.imshow(image)\n", + " plt.axis('off')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JGAsS4mQ3dnX" + }, + "source": [ + "We will not show the images here due to copyright issues. You can set `show_images=True` to display them (note that you can't set it to `True` unless you've downloaded the images at [this cell](#scrollTo=Ke6EeKAqj1vB\u0026line=12\u0026uniqifier=1)). Please check the post links for the license of each image." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 172, + "status": "ok", + "timestamp": 1651934792149, + "user": { + "displayName": "Zonglin Li", + "userId": "11843710831668693042" + }, + "user_tz": 240 + }, + "id": "v7g0RmYjks9i", + "outputId": "18a154fd-884d-4ad1-9498-416691831758" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Flickr url for 1: http://flickr.com/photo.gne?id=6388219123\n", + "Flickr url for 2: http://flickr.com/photo.gne?id=30100145\n", + "Flickr url for 3: http://flickr.com/photo.gne?id=3322126404\n", + "Flickr url for 4: http://flickr.com/photo.gne?id=4945223078\n", + "Flickr url for 5: http://flickr.com/photo.gne?id=120446248\n", + "Flickr url for 6: http://flickr.com/photo.gne?id=4807048033\n" + ] + } + ], + "source": [ + "search_image_with_text('A man riding on a bike')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9w2hEIF768be" + }, + "source": [ + "Congratulations on finishing this colab! For next steps, you can try deploying the model on-device (inference + search on Pixel 6 is around 6 ms), or you can train the model with your own dataset.
In the meantime, don't forget to check out our documentation ([Model Maker](https://www.tensorflow.org/lite/guide/model_maker/), [Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/text_searcher/)) and the [reference app](https://github.com/tensorflow/examples/tree/master/lite/examples/text_searcher/android), which searches news articles in the [CNN_DailyMail dataset](https://www.tensorflow.org/datasets/catalog/cnn_dailymail)." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "On-device Text-to-Image Search with TensorFlow Lite Searcher Library", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/BUILD index 233f02c..ec9a845 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/BUILD
@@ -1,3 +1,5 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + package( default_visibility = [ "//tensorflow_lite_support:internal", @@ -85,17 +87,19 @@ }), ) -cc_library( +cc_library_with_tflite( name = "universal_sentence_encoder_qa_op_resolver", srcs = ["universal_sentence_encoder_qa_op_resolver.cc"], hdrs = ["universal_sentence_encoder_qa_op_resolver.h"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + ], deps = [ "//tensorflow_lite_support/custom_ops/kernel/ragged:ragged_tensor_to_tensor_tflite", # fixdeps: keep "//tensorflow_lite_support/custom_ops/kernel/sentencepiece:sentencepiece_tokenizer_tflite", # fixdeps: keep "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/lite:op_resolver", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], ) @@ -118,3 +122,56 @@ "@com_google_absl//absl/strings", ], ) + +# Example usage: +# bazel run -c opt \ +# tensorflow_lite_support/examples/task/text/desktop:text_searcher_demo \ +# -- \ +# --model_path=/path/to/model.tflite \ +# --index_path=/path/to/index.ldb \ +# --input_sentence="your_input" +cc_binary( + name = "text_searcher_demo", + srcs = ["text_searcher_demo.cc"], + deps = [ + ":universal_sentence_encoder_qa_op_resolver", + "//tensorflow_lite_support/cc/port:configuration_proto_inc", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/task/core/proto:base_options_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "//tensorflow_lite_support/cc/task/processor/proto:embedding_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:search_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:search_result_cc_proto", + "//tensorflow_lite_support/cc/task/text:text_searcher", + "//tensorflow_lite_support/cc/task/text/proto:text_searcher_options_cc_proto", + 
"@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +) + +# Example usage: +# bazel run -c opt \ +# tensorflow_lite_support/examples/task/text/desktop:text_embedder_demo \ +# -- \ +# --model_path=/path/to/model.tflite \ +# --first_sentence="first sentence" \ +# --second_sentence="second sentence" +cc_binary( + name = "text_embedder_demo", + srcs = ["text_embedder_demo.cc"], + deps = [ + ":universal_sentence_encoder_qa_op_resolver", + "//tensorflow_lite_support/cc/port:configuration_proto_inc", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/task/core/proto:base_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:embedding_cc_proto", + "//tensorflow_lite_support/cc/task/text:text_embedder", + "//tensorflow_lite_support/cc/task/text/proto:text_embedder_options_cc_proto", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/README.md b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/README.md index 25528820..247a30f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/README.md +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/README.md
@@ -93,8 +93,6 @@ #### Prerequisites -TODO(b/163086702): Update the links to models with metadata attached. - You will need: * a Bert based TFLite text classification model from model maker. (e.g. @@ -162,9 +160,109 @@ Output answers 1: Paris is the capital of France. Score: 5.63752 ``` +## TextEmbedder + +#### Prerequisites + +You will need: + +* a TFLite text embedder model such as the universal sentence encoder QA model + from [TensorFlow Hub][6]. + +#### Usage + +The TextEmbedder demo tool takes two sentences as inputs, and outputs the +[cosine similarity][7] between their embeddings. + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/google/lite-model/universal-sentence-encoder-qa-ondevice/1?lite-format=tflite' \ + -o /tmp/universal_sentence_encoder_qa_with_metadata.tflite + +# Run the embedder tool: +bazel run -c opt \ +tensorflow_lite_support/examples/task/text/desktop:text_embedder_demo -- \ +--model_path=/tmp/universal_sentence_encoder_qa_with_metadata.tflite \ +--l2_normalize \ +--first_sentence="It was a very sunny day." \ +--second_sentence="The sun was shining brightly." +``` + +#### Results + +In the console, you should get: + +``` +Cosine similarity: 0.952549 +``` + +## TextSearcher + +#### Prerequisites + +You will need: + +* a TFLite text embedder model such as the universal sentence encoder QA model + from [TensorFlow Hub][6], +* an index built from that embedder model using [Model Maker][8]. + +Model Maker also provides the ability to add the index directly to the embedder +model metadata. The demo also supports this : just omit the `--index_path` +argument. + +#### Usage + +In this example, we'll be using a test index built from the universal sentence +encoder QA model, which only contains 5 embeddings. 
+ +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/google/lite-model/universal-sentence-encoder-qa-ondevice/1?lite-format=tflite' \ + -o /tmp/universal_sentence_encoder_qa_with_metadata.tflite + +# Run the searcher tool: +bazel run -c opt \ +tensorflow_lite_support/examples/task/text/desktop:text_searcher_demo -- \ +--model_path=/tmp/universal_sentence_encoder_qa_with_metadata.tflite \ +--l2_normalize \ +--index_path=$(pwd)/tensorflow_lite_support/cc/test/testdata/task/text/universal_sentence_encoder_index.ldb \ +--input_sentence="The sun was very bright." +``` + +#### Results + +In the console, you should get: + +``` +Results: + Rank#0: + metadata: The sun was shining on that day. + distance: 0.04618 + Rank#1: + metadata: It was a sunny day. + distance: 0.10856 + Rank#2: + metadata: The weather was excellent. + distance: 0.15223 + Rank#3: + metadata: The cat is chasing after the mouse. + distance: 0.34271 + Rank#4: + metadata: He was very happy with his newly bought car. + distance: 0.37703 +``` + [1]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 [2]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 [3]: https://www.tensorflow.org/lite/models/text_classification/overview [4]: https://github.com/tensorflow/tflite-support/blob/fe8b69002f5416900285dc69e2baa078c91bd994/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h#L55 [5]: http://bert/nl/classifier/model [6]: https://tfhub.dev/google/lite-model/universal-sentence-encoder-qa-ondevice/1 +[7]: https://en.wikipedia.org/wiki/Cosine_similarity +[8]: https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/searcher
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc new file mode 100644 index 0000000..eca8a00 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_embedder_demo.cc
@@ -0,0 +1,155 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Example usage: +// bazel run -c opt \ +// tensorflow_lite_support/examples/task/text/desktop:text_embedder_demo \ +// -- \ +// --model_path=/path/to/model.tflite \ +// --first_sentence="first sentence" \ +// --second_sentence="second sentence" + +#include <chrono> +#include <iostream> +#include <memory> +#include <string> + +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/flags/parse.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h" +#include "tensorflow_lite_support/cc/task/text/proto/text_embedder_options.pb.h" +#include "tensorflow_lite_support/cc/task/text/text_embedder.h" +#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h" + +ABSL_FLAG(std::string, + model_path, + "", + "Absolute path to the '.tflite' text embedder model."); +ABSL_FLAG(std::string, + first_sentence, + "", + "First sentence, whose feature vector will be extracted and compared " + "to the second sentence using cosine similarity."); +ABSL_FLAG(std::string, + second_sentence, + "", + "Second sentence, whose feature vector will be extracted and " + "compared to the first sentence using cosine similarity."); +ABSL_FLAG(bool, + l2_normalize, + false, + "If true, the raw feature vectors returned by the text embedder " + "will be normalized with L2-norm. Generally only needed if the model " + "doesn't already contain a L2_NORMALIZATION TFLite Op."); +ABSL_FLAG(bool, + use_coral, + false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +namespace tflite { +namespace task { +namespace text { + +namespace { +using std::chrono::steady_clock; +} // namespace + +// Builds the TextEmbedderOptions from the command-line flags. +TextEmbedderOptions BuildOptions() { + TextEmbedderOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + if (absl::GetFlag(FLAGS_l2_normalize)) { + options.add_embedding_options()->set_l2_normalize(true); + } + if (absl::GetFlag(FLAGS_use_coral)) { + options.mutable_base_options() + ->mutable_compute_settings() + ->mutable_tflite_settings() + ->set_delegate(::tflite::proto::Delegate::EDGETPU_CORAL); + } + return options; +} + +// Embeds both sentences and prints the cosine similarity of their embeddings. +absl::Status ComputeCosineSimilarity() { + // Build TextEmbedder. + const TextEmbedderOptions options = BuildOptions(); + ASSIGN_OR_RETURN( + std::unique_ptr<TextEmbedder> text_embedder, + TextEmbedder::CreateFromOptions(options, CreateQACustomOpResolver())); + + // Compute the first sentence's embedding and time the inference. + auto start_embed = steady_clock::now(); + ASSIGN_OR_RETURN(processor::EmbeddingResult first_embedding, + text_embedder->Embed(absl::GetFlag(FLAGS_first_sentence))); + auto end_embed = steady_clock::now(); + std::string delegate = + absl::GetFlag(FLAGS_use_coral) ? "Coral Edge TPU" : "CPU"; + std::cout << "Time cost to compute embedding for first sentence on " + << delegate << ": " + << std::chrono::duration<float, std::milli>(end_embed - start_embed) + .count() + << " ms" << std::endl; + + ASSIGN_OR_RETURN(processor::EmbeddingResult second_embedding, + text_embedder->Embed(absl::GetFlag(FLAGS_second_sentence))); + // Compute cosine similarity. + ASSIGN_OR_RETURN(double cosine_similarity, + TextEmbedder::CosineSimilarity( + first_embedding.embeddings(0).feature_vector(), + second_embedding.embeddings(0).feature_vector())); + + // Display result. + std::cout << absl::StrFormat("Cosine similarity: %f\n", cosine_similarity); + return absl::OkStatus(); +} + +} // namespace text +} // namespace task +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line and perform sanity checks. + absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_first_sentence).empty()) { + std::cerr << "Missing mandatory 'first_sentence' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_second_sentence).empty()) { + std::cerr << "Missing mandatory 'second_sentence' argument.\n"; + return 1; + } + + // Compute and display the cosine similarity. + absl::Status status = tflite::task::text::ComputeCosineSimilarity(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Cosine similarity computation failed: " << status.message() + << "\n"; + return 1; + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc new file mode 100644 index 0000000..02994289 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/text_searcher_demo.cc
@@ -0,0 +1,166 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Example usage: +// bazel run -c opt \ +// tensorflow_lite_support/examples/task/text/desktop:text_searcher_demo \ +// -- \ +// --model_path=/path/to/model.tflite \ +// --index_path=/path/to/index.ldb \ +// --input_sentence="your_input" + +#include <chrono> +#include <iostream> +#include <memory> +#include <string> + +#include "absl/flags/flag.h" // from @com_google_absl +#include "absl/flags/parse.h" // from @com_google_absl +#include "absl/status/status.h" // from @com_google_absl +#include "absl/strings/str_format.h" // from @com_google_absl +#include "tensorflow_lite_support/cc/port/configuration_proto_inc.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options.pb.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/search_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h" +#include "tensorflow_lite_support/cc/task/text/proto/text_searcher_options.pb.h" +#include "tensorflow_lite_support/cc/task/text/text_searcher.h" +#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h" + +ABSL_FLAG(std::string, + model_path, + "", + "Absolute path to the '.tflite' text embedder model."); +ABSL_FLAG(std::string, + index_path, + "", + "Absolute path to the index to search into. Mandatory only if the " + "index is not attached to the output tensor metadata of the embedder " + "model as an AssociatedFile with type SCANN_INDEX_FILE."); +ABSL_FLAG(std::string, + input_sentence, + "", + "Input sentence whose nearest-neighbors to search for in the index."); +ABSL_FLAG(int32, + max_results, + 5, + "Maximum number of nearest-neighbors to display."); +ABSL_FLAG(bool, + l2_normalize, + false, + "If true, the raw feature vectors returned by the text embedder " + "will be normalized with L2-norm. Generally only needed if the model " + "doesn't already contain a L2_NORMALIZATION TFLite Op."); +ABSL_FLAG(bool, + use_coral, + false, + "If true, inference will be delegated to a connected Coral Edge TPU " + "device."); + +namespace tflite { +namespace task { +namespace text { + +namespace { +using std::chrono::steady_clock; +} // namespace + +// Builds the TextSearcherOptions from the command-line flags. +TextSearcherOptions BuildOptions() { + TextSearcherOptions options; + options.mutable_base_options()->mutable_model_file()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + if (absl::GetFlag(FLAGS_l2_normalize)) { + options.mutable_embedding_options()->set_l2_normalize(true); + } + if (!absl::GetFlag(FLAGS_index_path).empty()) { + options.mutable_search_options()->mutable_index_file()->set_file_name( + absl::GetFlag(FLAGS_index_path)); + } + options.mutable_search_options()->set_max_results( + absl::GetFlag(FLAGS_max_results)); + if (absl::GetFlag(FLAGS_use_coral)) { + options.mutable_base_options() + ->mutable_compute_settings() + ->mutable_tflite_settings() + ->set_delegate(::tflite::proto::Delegate::EDGETPU_CORAL); + } + return options; +} + +// Prints each nearest neighbor's rank, metadata and distance. +void DisplayResults(const processor::SearchResult& result) { + std::cout << "Results:\n"; + for (int rank = 0; rank < result.nearest_neighbors_size(); ++rank) { + const auto& neighbor = result.nearest_neighbors(rank); + std::cout << absl::StrFormat(" Rank#%d:\n", rank); + std::cout << absl::StrFormat(" metadata: %s\n", neighbor.metadata()); + std::cout << absl::StrFormat(" distance: %.5f\n", neighbor.distance()); + } +} + +// Searches the index for the input sentence and prints the results. +absl::Status Search() { + // Build TextSearcher. + const TextSearcherOptions options = BuildOptions(); + ASSIGN_OR_RETURN( + std::unique_ptr<TextSearcher> text_searcher, + TextSearcher::CreateFromOptions(options, CreateQACustomOpResolver())); + + // Run search and display results. + auto start_search = steady_clock::now(); + ASSIGN_OR_RETURN(processor::SearchResult result, + text_searcher->Search(absl::GetFlag(FLAGS_input_sentence))); + auto end_search = steady_clock::now(); + std::string delegate = + absl::GetFlag(FLAGS_use_coral) ? "Coral Edge TPU" : "CPU"; + std::cout << "Time cost to search the input text on " << delegate << ": " + << std::chrono::duration<float, std::milli>(end_search - + start_search) + .count() + << " ms" << std::endl; + + DisplayResults(result); + + return absl::OkStatus(); +} + +} // namespace text +} // namespace task +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line and perform sanity checks. + absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_input_sentence).empty()) { + std::cerr << "Missing mandatory 'input_sentence' argument.\n"; + return 1; + } + + // Run search. + absl::Status status = tflite::task::text::Search(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Search failed: " << status.message() << "\n"; + return 1; + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.cc index 21e4f26..00dd534 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.cc
@@ -15,7 +15,7 @@ #include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h" #include "absl/memory/memory.h" // from @com_google_absl -#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" namespace tflite { namespace ops { @@ -32,7 +32,8 @@ // Creates custom op resolver for USE QA task. std::unique_ptr<tflite::OpResolver> CreateQACustomOpResolver() { - auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(); + auto resolver = + absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>(); resolver->AddCustom( "TFSentencepieceTokenizeOp", ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/BUILD index 45393140..6b1ab5d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/BUILD
@@ -21,7 +21,7 @@ "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", ] + select({ "//tensorflow_lite_support/examples/task:darwinn_portable": [ "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", @@ -49,7 +49,7 @@ "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", ] + select({ "//tensorflow_lite_support/examples/task:darwinn_portable": [ "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", @@ -75,7 +75,7 @@ "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", ] + select({ "//tensorflow_lite_support/examples/task:darwinn_portable": [ "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", @@ -100,7 +100,7 @@ "//tensorflow_lite_support/cc/task/vision/proto:embeddings_proto_inc", "//tensorflow_lite_support/cc/task/vision/proto:image_embedder_options_proto_inc", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + 
"//tensorflow_lite_support/cc/task/vision/utils:image_utils", ] + select({ "//tensorflow_lite_support/examples/task:darwinn_portable": [ "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin", @@ -129,7 +129,7 @@ "//tensorflow_lite_support/cc/task/vision/proto:image_searcher_options_cc_proto", "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", "//tensorflow_lite_support/cc/task/vision:image_searcher", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", ] + select({ "//tensorflow_lite_support/examples/task:darwinn_portable": [ "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/README.md b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/README.md index ef7969a..dea0cec0 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/README.md +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/README.md
@@ -259,7 +259,110 @@  +## ImageEmbedder + +#### Prerequisites + +You will need: + +* a TFLite image embedder model (e.g. [mobilenet v3][5], a generic image + embedder trained on ImageNet), +* two PNG, JPEG or GIF images to extract embeddings from. + +#### Usage + +The ImageEmbedder demo tool takes two images as inputs, and outputs the +[cosine similarity][6] between their embeddings. + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/google/lite-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/metadata/1?lite-format=tflite' \ + -o /tmp/mobilenet_v3_embedder.tflite + +# Run the embedder tool: +bazel run -c opt \ +tensorflow_lite_support/examples/task/vision/desktop:image_embedder_demo -- \ +--model_path=/tmp/mobilenet_v3_embedder.tflite \ +--l2_normalize \ +--first_image_path=$(pwd)/tensorflow_lite_support/cc/test/testdata/task/vision/burger.jpg \ +--second_image_path=$(pwd)/tensorflow_lite_support/cc/test/testdata/task/vision/burger_crop.jpg +``` + +#### Results + +In the console, you should get: + +``` +Cosine similarity: 0.932738 +``` + +## ImageSearcher + +#### Prerequisites + +You will need: + +* a TFLite image embedder model (e.g. [mobilenet v3][5], a generic image + embedder trained on ImageNet), +* an index built from that embedder model using [Model Maker][7]. + +Model Maker also provides the ability to add the index directly to the embedder +model metadata. The demo also supports this: just omit the `--index_path` +argument. + +#### Usage + +In this example, we'll be using a test index built from the mobilenet v3 +embedder model, which only contains 5 embeddings extracted from images of a +burger, a cat, a dog, a bird and a car.
+ +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/google/lite-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/metadata/1?lite-format=tflite' \ + -o /tmp/mobilenet_v3_embedder.tflite + +# Run the searcher tool: +bazel run -c opt \ +tensorflow_lite_support/examples/task/vision/desktop:image_searcher_demo -- \ +--model_path=/tmp/mobilenet_v3_embedder.tflite \ +--l2_normalize \ +--index_path=$(pwd)/tensorflow_lite_support/cc/test/testdata/task/vision/searcher_index.ldb \ +--image_path=$(pwd)/tensorflow_lite_support/cc/test/testdata/task/vision/burger_crop.jpg +``` + +#### Results + +In the console, you should get: + +``` +Results: + Rank#0: + metadata: burger + distance: 0.13452 + Rank#1: + metadata: car + distance: 1.81935 + Rank#2: + metadata: bird + distance: 1.96617 + Rank#3: + metadata: dog + distance: 2.05610 + Rank#4: + metadata: cat + distance: 2.06347 +``` + [1]: https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/3 [2]: https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2 [3]: https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 [4]: https://coral.ai/docs/edgetpu/inference/ +[5]: https://tfhub.dev/google/lite-model/imagenet/mobilenet_v3_small_100_224/feature_vector/5/metadata/1 +[6]: https://en.wikipedia.org/wiki/Cosine_similarity +[7]: https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/searcher
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc index bd2aaaf1..0904920 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
@@ -34,7 +34,7 @@ #include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" ABSL_FLAG(std::string, model_path,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc index 040878a..f8b1796 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_embedder_demo.cc
@@ -37,7 +37,7 @@ #include "tensorflow_lite_support/cc/task/vision/proto/embeddings_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/image_embedder_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" ABSL_FLAG(std::string, model_path,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc index a188b5d..e4074f7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_searcher_demo.cc
@@ -40,7 +40,7 @@ #include "tensorflow_lite_support/cc/task/vision/image_searcher.h" #include "tensorflow_lite_support/cc/task/vision/proto/image_searcher_options.pb.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" ABSL_FLAG(std::string, model_path, @@ -49,7 +49,9 @@ ABSL_FLAG(std::string, index_path, "", - "Absolute path to the index to search into."); + "Absolute path to the index to search into. Mandatory only if the " + "index is not attached to the output tensor metadata of the embedder " + "model as an AssociatedFile with type SCANN_INDEX_FILE."); ABSL_FLAG(std::string, image_path, "", @@ -57,9 +59,9 @@ "RGBA (grayscale is not supported). The image EXIF orientation " "flag, if any, is NOT taken into account."); ABSL_FLAG(int32, - num_results, + max_results, 5, - "Number of nearest-neighbor results to display."); + "Maximum number of nearest-neighbor results to display."); ABSL_FLAG(bool, l2_normalize, false, @@ -88,10 +90,12 @@ if (absl::GetFlag(FLAGS_l2_normalize)) { options.mutable_embedding_options()->set_l2_normalize(true); } - options.mutable_search_options()->mutable_index_file()->set_file_name( - absl::GetFlag(FLAGS_index_path)); - options.mutable_search_options()->set_num_results( - absl::GetFlag(FLAGS_num_results)); + if (!absl::GetFlag(FLAGS_index_path).empty()) { + options.mutable_search_options()->mutable_index_file()->set_file_name( + absl::GetFlag(FLAGS_index_path)); + } + options.mutable_search_options()->set_max_results( + absl::GetFlag(FLAGS_max_results)); if (absl::GetFlag(FLAGS_use_coral)) { options.mutable_base_options() ->mutable_compute_settings() @@ -164,10 +168,6 @@ std::cerr << "Missing mandatory 'model_path' argument.\n"; return 1; } - if (absl::GetFlag(FLAGS_index_path).empty()) { - std::cerr << "Missing mandatory 'index_path' argument.\n"; - 
return 1; - } if (absl::GetFlag(FLAGS_image_path).empty()) { std::cerr << "Missing mandatory 'image_path' argument.\n"; return 1;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc index 6487fe9..fdc7872 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
@@ -35,7 +35,7 @@ #include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" ABSL_FLAG(std::string, model_path,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc index 9208439..fd000fc 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
@@ -38,7 +38,7 @@ #include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h" #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" -#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" ABSL_FLAG(std::string, model_path,
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD deleted file mode 100644 index b3009dc..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD +++ /dev/null
@@ -1,27 +0,0 @@ -package( - default_visibility = [ - "//tensorflow_lite_support:internal", - ], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "image_utils", - srcs = ["image_utils.cc"], - hdrs = ["image_utils.h"], - visibility = [ - "//tensorflow_lite_support:internal", - ], - deps = [ - "//tensorflow_lite_support/cc/port:integral_types", - "//tensorflow_lite_support/cc/port:status_macros", - "//tensorflow_lite_support/cc/port:statusor", - "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", - "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@stblib//:stb_image", - "@stblib//:stb_image_write", - ], -)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template b/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template index 63bd9f7b..71b6345 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template
@@ -2,7 +2,7 @@ s.name = 'TensorFlowLiteTaskText' s.version = '${TFLS_BUILD_VERSION}' s.authors = 'Google Inc.' - s.license = { :type => 'Apache' } + s.license = { :type => 'Apache',:file => "LICENSE" } s.homepage = 'https://github.com/tensorflow/tflite-support' s.source = { :http => '${TFLS_DOWNLOAD_URL}' } s.summary = 'TensorFlow Lite Task Library - Text' @@ -21,9 +21,9 @@ objc_dir + '{nlclassifier,qa}/Sources/*.h' ] - cc_dir = 'tensorflow_lite_support/cc/task/text/' + c_dir = 'tensorflow_lite_support/c/task/text/' s.source_files = [ - cc_dir + '{nlclassifier,qa}/*_c_api*.h', + c_dir + '*.h', objc_dir + 'apis/*.h', objc_dir + '{nlclassifier,qa}/Sources/*.{h,m,mm}' ] @@ -31,8 +31,7 @@ s.pod_target_xcconfig = { 'HEADER_SEARCH_PATHS' => '"${PODS_TARGET_SRCROOT}" ' + - '"${PODS_TARGET_SRCROOT}/' + cc_dir + 'nlclassifier" ' + - '"${PODS_TARGET_SRCROOT}/' + cc_dir + 'qa" ' + + '"${PODS_TARGET_SRCROOT}/' + c_dir + '" ' + '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'apis" ' + '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'nlclassifier/Sources" ' + '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'qa/Sources"',
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskVision.podspec.template b/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskVision.podspec.template index 79473c4..137d891 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskVision.podspec.template +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/TensorFlowLiteTaskVision.podspec.template
@@ -2,7 +2,7 @@ s.name = 'TensorFlowLiteTaskVision' s.version = '${TFLS_BUILD_VERSION}' s.authors = 'Google Inc.' - s.license = { :type => 'Apache',:file => "LICENSE.txt"} + s.license = { :type => 'Apache',:file => "LICENSE"} s.homepage = 'https://github.com/tensorflow/tflite-support' s.source = { :http => '${TFLS_DOWNLOAD_URL}' } s.summary = 'TensorFlow Lite Task Library - Vision'
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/BUILD new file mode 100644 index 0000000..04dbe26 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/BUILD
@@ -0,0 +1,33 @@ +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLFloatBuffer", + srcs = [ + "sources/TFLFloatBuffer.m", + ], + hdrs = [ + "sources/TFLFloatBuffer.h", + ], + module_name = "TFLFloatBuffer", + deps = ["//third_party/apple_frameworks:Foundation"], +) + +objc_library( + name = "TFLRingBuffer", + srcs = [ + "sources/TFLRingBuffer.m", + ], + hdrs = [ + "sources/TFLRingBuffer.h", + ], + module_name = "TFLRingBuffer", + deps = [ + ":TFLFloatBuffer", + "//tensorflow_lite_support/ios:TFLCommon", + "//tensorflow_lite_support/ios:TFLCommonUtils", + "//third_party/apple_frameworks:Foundation", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h new file mode 100644 index 0000000..a5db970 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h
@@ -0,0 +1,56 @@ +// Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <Foundation/Foundation.h> + +NS_ASSUME_NONNULL_BEGIN + +/** A wrapper class that owns a float array and stores its size. */ +@interface TFLFloatBuffer : NSObject <NSCopying> + +/** Capacity of the array in number of elements. */ +@property(nonatomic, readonly) NSUInteger size; + +/** Pointer to the float array owned by `TFLFloatBuffer`; freed on dealloc. */ +@property(nonatomic, readonly) float* data; + +/** + * Initializes a new `TFLFloatBuffer` by copying the elements of the given float + * data array. + * + * @param data A pointer to a float data array whose values are to be copied + * into the buffer. + * @param size Number of elements of the float data array to copy. + * + * @return A new instance of `TFLFloatBuffer` initialized with the elements of + * the given float data array. + */ +- (instancetype)initWithData:(float*)data size:(NSUInteger)size; + +/** + * Initializes a `TFLFloatBuffer` of the specified size with zeros. + * + * @param size Number of elements the `TFLFloatBuffer` can hold. + * + * @return A new instance of `TFLFloatBuffer` of the given size with all + * elements initialized to zero. + */ +- (instancetype)initWithSize:(NSUInteger)size; + +/** Clears the `TFLFloatBuffer` by setting all elements to zero. */ +- (void)clear; + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m new file mode 100644 index 0000000..d32fc43 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.m
@@ -0,0 +1,60 @@ +// Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h" + +@implementation TFLFloatBuffer + +- (instancetype)initWithData:(float*)data size:(NSUInteger)size { + self = [self init]; + if (self) { + _size = size; + // calloc zero-fills, so the buffer is never left uninitialized when `data` + // is NULL, and it avoids the unchecked `sizeof(float) * size` multiply. + _data = calloc(size, sizeof(float)); + if (!_data) { + exit(-1); + } + if (data) { + memcpy(_data, data, sizeof(float) * size); + } + } + return self; +} + +- (instancetype)initWithSize:(NSUInteger)size { + self = [self init]; + if (self) { + _size = size; + _data = calloc(size, sizeof(float)); + if (!_data) { + exit(-1); + } + } + return self; +} + +- (id)copyWithZone:(NSZone*)zone { + return [[TFLFloatBuffer alloc] initWithData:_data size:_size]; +} + +- (void)clear { + memset(_data, 0, sizeof(float) * _size); +} + +- (void)dealloc { + free(_data); +} + +@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h new file mode 100644 index 0000000..b300de6b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h
@@ -0,0 +1,79 @@ +// Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import <Foundation/Foundation.h> +#import "TFLFloatBuffer.h" + +NS_ASSUME_NONNULL_BEGIN + +/** A wrapper class which stores a buffer that is written in circular fashion. + */ +@interface TFLRingBuffer : NSObject + +/** + * A copy of all the internal ring buffer elements in order, oldest first. + */ +@property(nullable, nonatomic, readonly) TFLFloatBuffer* floatBuffer; + +/** + * Capacity of the ring buffer in number of elements. + */ +@property(nonatomic, readonly) NSUInteger size; + +/** + * Initializes a new `TFLRingBuffer` with the given size. All elements of the + * `TFLRingBuffer` will be initialized to zero. + * + * @param size Size of the ring buffer. + * + * @return A new instance of `TFLRingBuffer` with the given size and all + * elements initialized to zero. + */ +- (instancetype)initWithBufferSize:(NSUInteger)size; + +/** + * Loads the slice sourceBuffer[offset:offset+size] into the ring buffer. If + * size exceeds the ring buffer's capacity, only the last `self.size` elements + * of the slice are kept; elements with lower indices are ignored. + * + * @return YES on success; NO if offset + size exceeds the size of `sourceBuffer`, in which case an invalid-argument error is reported through `error`.
+ */ +- (BOOL)loadBuffer:(TFLFloatBuffer*)sourceBuffer + offset:(NSUInteger)offset + size:(NSUInteger)size + error:(NSError**)error; + +/** + * Returns a `TFLFloatBuffer` holding a copy of `size` of the ring buffer + * elements in order starting at offset, i.e., buffer[offset:offset+size]. + * + * @param offset Offset in the ring buffer from which elements are to be + * returned. + * + * @param size Number of elements to be returned. + * + * @return A new `TFLFloatBuffer` if offset + size is within the bounds of the + * ring buffer, otherwise nil. + */ +- (nullable TFLFloatBuffer*)floatBufferWithOffset:(NSUInteger)offset + size:(NSUInteger)size; + +/** + * Clears the `TFLRingBuffer` by setting all the elements to zero. + */ +- (void)clear; + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m new file mode 100644 index 0000000..5749540 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.m
@@ -0,0 +1,134 @@ +// Copyright 2022 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h" +#import "tensorflow_lite_support/ios/sources/TFLCommon.h" +#import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h" + +@implementation TFLRingBuffer { + NSUInteger _nextIndex; + TFLFloatBuffer* _buffer; +} + +- (instancetype)initWithBufferSize:(NSUInteger)size { + self = [self init]; + if (self) { + _buffer = [[TFLFloatBuffer alloc] initWithSize:size]; + } + return self; +} + +- (BOOL)loadBuffer:(TFLFloatBuffer*)sourceBuffer + offset:(NSUInteger)offset + size:(NSUInteger)size + error:(NSError**)error { + NSUInteger sizeToCopy = size; + NSUInteger newOffset = offset; + + // Compared in two steps so that `offset + size` cannot overflow. + if (offset > sourceBuffer.size || size > sourceBuffer.size - offset) { + [TFLCommonUtils createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError + description:@"offset + size exceeds the maximum size " + @"of the source buffer."]; + return NO; + } + + // A zero-capacity ring buffer cannot store any element. Returning early also + // avoids the modulo by zero when `_nextIndex` is updated below. + if (_buffer.size == 0) { + return YES; + } + + // If size is greater than the ring buffer capacity, modify size and offset + // to keep only the most recent data in the sourceBuffer. + if (size >= _buffer.size) { + sizeToCopy = _buffer.size; + newOffset = offset + (size - _buffer.size); + } + + // If the new nextIndex + sizeToCopy is smaller than the size of the ring + // buffer directly copy all elements to the end of the ring buffer.
+ if (_nextIndex + sizeToCopy < _buffer.size) { + memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset, + sizeof(float) * sizeToCopy); + } else { + NSUInteger endChunkSize = _buffer.size - _nextIndex; + memcpy(_buffer.data + _nextIndex, sourceBuffer.data + newOffset, + sizeof(float) * endChunkSize); + + NSUInteger startChunkSize = sizeToCopy - endChunkSize; + memcpy(_buffer.data, sourceBuffer.data + newOffset + endChunkSize, + sizeof(float) * startChunkSize); + } + + _nextIndex = (_nextIndex + sizeToCopy) % _buffer.size; + + return YES; +} + +- (TFLFloatBuffer*)floatBuffer { + return [self floatBufferWithOffset:0 size:self.size]; +} + +- (nullable TFLFloatBuffer*)floatBufferWithOffset:(NSUInteger)offset + size:(NSUInteger)size { + // Compared in two steps so that `offset + size` cannot overflow. + if (offset > _buffer.size || size > _buffer.size - offset) { + return nil; + } + + TFLFloatBuffer* bufferToReturn = [[TFLFloatBuffer alloc] initWithSize:size]; + + // Nothing to copy. Returning early also avoids the modulo by zero below when + // the ring buffer has zero capacity. + if (size == 0) { + return bufferToReturn; + } + + // Return buffer in correct order. + // Compute offset in flat ring buffer array considering wrapping. + NSUInteger correctOffset = (_nextIndex + offset) % _buffer.size; + + // If the elements to be copied lie before the end of the flat ring buffer. + if ((correctOffset + size) <= _buffer.size) { + memcpy(bufferToReturn.data, _buffer.data + correctOffset, + sizeof(float) * size); + } else { + // If the elements to be copied wrap around to the beginning of the ring + // buffer. Copy the chunk starting at ringBuffer[nextIndex + offset : size] + // to beginning of the result array. + NSUInteger endChunkSize = _buffer.size - correctOffset; + memcpy(bufferToReturn.data, _buffer.data + correctOffset, + sizeof(float) * endChunkSize); + + // Next copy the chunk starting at ringBuffer[0 : size - endChunkSize] to + // the result array.
+ NSUInteger firstChunkSize = size - endChunkSize; + memcpy(bufferToReturn.data + endChunkSize, _buffer.data, + sizeof(float) * firstChunkSize); + } + + return bufferToReturn; +} + +- (void)clear { + [_buffer clear]; +} + +- (NSUInteger)size { + return _buffer.size; +} + +@end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h index 0f92dd1..7ab7e72 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h
@@ -24,7 +24,7 @@ * @discussion This property hould be greater than 0 or equal to -1. Setting it * to -1 has the effect to let TFLite runtime set the value. */ -@property(nonatomic, assign) int numThreads; +@property(nonatomic) int numThreads; @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/BUILD index 834acfd22..146c049 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/BUILD
@@ -62,6 +62,9 @@ "sources/TFLSegmentationResult.h", ], module_name = "TFLSegmentationResult", + deps = [ + "//tensorflow_lite_support/ios:TFLCommonUtils", + ], ) objc_library(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m index 7a49281c..41395255 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory+Helpers.m
@@ -19,8 +19,6 @@ + (TFLCategory *)categoryWithCCategory:(TfLiteCategory *)cCategory { if (cCategory == nil) return nil; - TFLCategory *category = [[TFLCategory alloc] init]; - NSString* displayName; NSString* label;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h index b5b19af5..5c521f22 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLCategory.h
@@ -17,41 +17,47 @@ NS_ASSUME_NONNULL_BEGIN /** Encapsulates information about a class in the classification results. */ +NS_SWIFT_NAME(ClassificationCategory) @interface TFLCategory : NSObject /** Index of the class in the corresponding label map, usually packed in the * TFLite Model Metadata. */ -@property(nonatomic, assign, readonly) NSInteger index; +@property(nonatomic, readonly) NSInteger index; /** Confidence score for this class . */ -@property(nonatomic, assign, readonly) float score; +@property(nonatomic, readonly) float score; /** Class name of the class. */ -@property(nonatomic, copy, readonly, nullable) NSString* label; +@property(nonatomic, readonly, nullable) NSString* label; /** Display name of the class. */ -@property(nonatomic, copy, readonly, nullable) NSString* displayName; +@property(nonatomic, readonly, nullable) NSString* displayName; /** - * Initializes TFLCategory. + * Initializes a new `TFLCategory` with the given index, score, label and + * display name. * * @param index Index of the class in the corresponding label map, usually * packed in the TFLite Model Metadata. * - * @param score Confidence score for this class . + * @param score Confidence score for this class. * * @param label Class name of the class. * * @param displayName Display name of the class. * - * @return An instance of TFLCategory initialized to - * the specified values. + * @return An instance of `TFLCategory` initialized with the given index, score, + * label and display name. */ - (instancetype)initWithIndex:(NSInteger)index score:(float)score label:(nullable NSString*)label displayName:(nullable NSString*)displayName; +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + @end NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h index 33ecd08..152aa33 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h
@@ -21,7 +21,7 @@ - (BOOL)copyToCOptions:(TfLiteClassificationOptions*)cClassificationOptions error:(NSError**)error; -- (void)deleteCStringArraysOfClassificationOptions: +- (void)deleteAllocatedMemoryOfClassificationOptions: (TfLiteClassificationOptions*)cClassificationOptions; @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m index 1d554ca..767e5e4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.m
@@ -101,8 +101,8 @@ return YES; } -- (void)deleteCStringArraysOfClassificationOptions: - (TfLiteClassificationOptions *)cClassificationOptions { +- (void)deleteAllocatedMemoryOfClassificationOptions: + (TfLiteClassificationOptions*)cClassificationOptions { if (self.labelAllowList) { [TFLClassificationOptions deleteCStringsArray:cClassificationOptions->label_allowlist.list count:cClassificationOptions->label_allowlist.length]; @@ -112,5 +112,8 @@ [TFLClassificationOptions deleteCStringsArray:cClassificationOptions->label_denylist.list count:cClassificationOptions->label_denylist.length]; } + + free(cClassificationOptions->display_names_local); } + @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h index 8e52068..ce3f5d658 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions.h
@@ -19,6 +19,7 @@ /** * Holds settings for any single classification task. */ +NS_SWIFT_NAME(ClassificationOptions) @interface TFLClassificationOptions : NSObject <NSCopying> /** If set, all classes in this list will be filtered out from the results . */ @@ -32,10 +33,10 @@ @property(nonatomic, copy) NSString* displayNamesLocale; /** Results with score threshold greater than this value are returned . */ -@property(nonatomic, assign) float scoreThreshold; +@property(nonatomic) float scoreThreshold; /** Limit to the number of classes that can be returned in results. */ -@property(nonatomic, assign) NSInteger maxResults; +@property(nonatomic) NSInteger maxResults; @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h index 9c8b92c..351e87d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.h
@@ -20,7 +20,7 @@ @interface TFLClassificationResult (Helpers) /** - * Creates and retrurns a TFLClassificationResult from a + * Creates and returns a TFLClassificationResult from a * TfLiteClassificationResult returned by TFLite Task C Library Classification * tasks. *
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m index 1083e60..52e9285 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult+Helpers.m
@@ -19,7 +19,8 @@ + (TFLClassificationResult *)classificationResultWithCResult: (TfLiteClassificationResult *)cClassificationResult { - if (cClassificationResult == nil) return nil; + if (!cClassificationResult) + return nil; NSMutableArray *classificationHeads = [[NSMutableArray alloc] init]; for (int i = 0; i < cClassificationResult->size; i++) { @@ -29,8 +30,18 @@ TfLiteCategory cCategory = cClassifications.categories[j]; [categories addObject:[TFLCategory categoryWithCCategory:&cCategory]]; } - TFLClassifications* classifications = - [[TFLClassifications alloc] initWithHeadIndex:i categories:categories]; + + NSString* headName = nil; + + if (cClassifications.head_name) { + headName = [NSString stringWithCString:cClassifications.head_name + encoding:NSUTF8StringEncoding]; + } + + TFLClassifications* classifications = [[TFLClassifications alloc] + initWithHeadIndex:cClassifications.head_index + headName:headName + categories:categories]; [classificationHeads addObject:classifications]; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h index 80b1aafb..052b4f1 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.h
@@ -19,52 +19,85 @@ /** Encapsulates list of predicted classes (aka labels) for a given image * classifier head. */ +NS_SWIFT_NAME(Classifications) @interface TFLClassifications : NSObject /** * The index of the image classifier head these classes refer to. This is useful * for multi-head models. */ -@property(nonatomic, assign, readonly) NSInteger headIndex; +@property(nonatomic, readonly) NSInteger headIndex; + +/** The name of the classifier head, which is the corresponding tensor metadata + * name. See + * https://github.com/tensorflow/tflite-support/blob/710e323265bfb71fdbdd72b3516e00cff15c0326/tensorflow_lite_support/metadata/metadata_schema.fbs#L545 + * This will always be `nil` for the `TFLClassifications` in the + * `TFLClassificationResult` returned by the following methods of + * `TFLImageClassifier`. + * 1. -[TFLImageClassifier classifyWithGMLImage:error:] + * 2. -[TFLImageClassifier classifyWithGMLImage:regionOfInterest:error:] + */ +@property(nonatomic, readonly, nullable) NSString* headName; /** The array of predicted classes, usually sorted by descending scores * (e.g.from high to low probability). */ -@property(nonatomic, copy, readonly) NSArray<TFLCategory*>* categories; +@property(nonatomic, readonly) NSArray<TFLCategory*>* categories; /** - * Initializes TFLClassifications. + * Initializes a new `TFLClassifications` with the given head index and array of + * categories. The head name is initialized to `nil`. * - * @param categories Array of TFLCategory objects encapsulating a list of + * @param headIndex The index of the image classifier head these classes refer + * to. + * @param categories An array of `TFLCategory` objects encapsulating a list of * predictions usually sorted by descending scores (e.g. from high to low * probability). - * @seealso TFLCategory * - * @return An instance of TFLClassifications initialized to - * the specified values. + * @return An instance of `TFLClassifications` initialized with the given head + * index and array of categories.
*/ - (instancetype)initWithHeadIndex:(NSInteger)headIndex categories:(NSArray<TFLCategory*>*)categories; +/** + * Initializes a new `TFLClassifications` with the given head index, head name + * and array of categories. + * + * @param headIndex The index of the image classifier head these classes refer + * to. + * @param headName The name of the classifier head, which is the corresponding + * tensor metadata name. May be `nil`. + * @param categories An array of `TFLCategory` objects encapsulating a list of + * predictions usually sorted by descending scores (e.g. from high to low + * probability). + * + * @return An object of `TFLClassifications` initialized with the given head + * index, head name and array of categories. + */ +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + headName:(nullable NSString*)headName + categories:(NSArray<TFLCategory*>*)categories; + @end /** Encapsulates results of any classification task. */ +NS_SWIFT_NAME(ClassificationResult) @interface TFLClassificationResult : NSObject /** Array of TFLClassifications objects containing image classifier predictions * per image classifier head. */ -@property(nonatomic, copy, readonly) - NSArray<TFLClassifications*>* classifications; +@property(nonatomic, readonly) NSArray<TFLClassifications*>* classifications; /** - * Initializes TFLClassificationResult. + * Initializes a new `TFLClassificationResult` with the given array of + * classifications. * - * @param classifications Array of TFLClassifications objects containing image - * classifier predictions per image classifier head. - * @seealso TFLClassifications + * @param classifications An array of `TFLClassifications` objects containing + * image classifier predictions per image classifier head. * - * @return An instance of TFLClassificationResult initialized to the specified - * values. + * @return An instance of `TFLClassificationResult` initialized with the given + * array of classifications.
*/ - (instancetype)initWithClassifications: (NSArray<TFLClassifications*>*)classifications;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m index b2ab012..0ea2384 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLClassificationResult.m
@@ -17,15 +17,22 @@ @implementation TFLClassifications - (instancetype)initWithHeadIndex:(NSInteger)headIndex + headName:(nullable NSString*)headName categories:(NSArray<TFLCategory*>*)categories { self = [super init]; if (self) { _headIndex = headIndex; + _headName = headName; _categories = categories; } return self; } +- (instancetype)initWithHeadIndex:(NSInteger)headIndex + categories:(NSArray<TFLCategory*>*)categories { + return [self initWithHeadIndex:headIndex headName:nil categories:categories]; +} + @end @implementation TFLClassificationResult {
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m index 799bcda..3ae292c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult+Helpers.m
@@ -17,9 +17,10 @@ @implementation TFLDetectionResult (Helpers) -+ (TFLDetectionResult *)detectionResultWithCResult: - (TfLiteDetectionResult *)cDetectionResult { - if (cDetectionResult == nil) return nil; ++ (TFLDetectionResult*)detectionResultWithCResult: + (TfLiteDetectionResult*)cDetectionResult { + if (!cDetectionResult) + return nil; NSMutableArray *detections = [[NSMutableArray alloc] init]; for (int i = 0; i < cDetectionResult->size; i++) { @@ -31,16 +32,15 @@ TFLCategory *resultCategory = [TFLCategory categoryWithCCategory:&cCategory]; [categories addObject:resultCategory]; } - TFLDetection *detection = [[TFLDetection alloc] init]; - detection.categories = categories; - detection.boundingBox = - CGRectMake(cDetection.bounding_box.origin_x, cDetection.bounding_box.origin_y, - cDetection.bounding_box.width, cDetection.bounding_box.height); + TFLDetection* detection = [[TFLDetection alloc] + initWithBoundingBox:CGRectMake(cDetection.bounding_box.origin_x, + cDetection.bounding_box.origin_y, + cDetection.bounding_box.width, + cDetection.bounding_box.height) + categories:categories]; [detections addObject:detection]; } - TFLDetectionResult *detectionResult = [[TFLDetectionResult alloc] init]; - detectionResult.detections = detections; - return detectionResult; + return [[TFLDetectionResult alloc] initWithDetections:detections]; } @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h index 4d7fc0c..00cc75bb 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.h
@@ -21,24 +21,59 @@ /** Encapsulates list of predicted classes (aka labels) and bounding box for a * detected object. */ +NS_SWIFT_NAME(Detection) @interface TFLDetection : NSObject /** - * The index of the image classifier head these classes refer to. This is useful - * for multi-head models. + * CGRect specifying the bounds of the object represented by this + * detection. */ -@property(nonatomic, assign) CGRect boundingBox; +@property(nonatomic, readonly) CGRect boundingBox; /** The array of predicted classes, usually sorted by descending scores * (e.g.from high to low probability). */ -@property(nonatomic, copy) NSArray<TFLCategory*>* categories; +@property(nonatomic, readonly) NSArray<TFLCategory*>* categories; + +/** + * Initializes an object of `TFLDetection` with the given bounding box and array + * of categories. + * + * @param boundingBox CGRect specifying the bounds of the object represented by + * this detection. + * @param categories Array of predicted classes, usually sorted by descending + * scores (e.g. from high to low probability). + * + * @return An instance of `TFLDetection` initialized with the given bounding box + * and array of categories. + */ +- (instancetype)initWithBoundingBox:(CGRect)boundingBox + categories:(NSArray<TFLCategory*>*)categories; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; @end /** Encapsulates results of any object detection task. */ +NS_SWIFT_NAME(DetectionResult) @interface TFLDetectionResult : NSObject -@property(nonatomic, copy) NSArray<TFLDetection*>* detections; +@property(nonatomic, readonly) NSArray<TFLDetection*>* detections; + +/** + * Initializes a new `TFLDetectionResult` with the given array of detections. + * + * @param detections Array of detected objects of type TFLDetection. + * + * @return An instance of `TFLDetectionResult` initialized with the given array + * of detections. + */ +- (instancetype)initWithDetections:(NSArray<TFLDetection*>*)detections; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m index 2b7d3a1..14cec3bca 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLDetectionResult.m
@@ -16,13 +16,26 @@ @implementation TFLDetection -@synthesize boundingBox; -@synthesize categories; +- (instancetype)initWithBoundingBox:(CGRect)boundingBox + categories:(NSArray<TFLCategory*>*)categories { + self = [super init]; + if (self) { + _boundingBox = boundingBox; + _categories = categories; + } + return self; +} @end @implementation TFLDetectionResult -@synthesize detections; +- (instancetype)initWithDetections:(NSArray<TFLDetection*>*)detections { + self = [super init]; + if (self) { + _detections = detections; + } + return self; +} @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m index b531e78d..2a897f0 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.m
@@ -18,7 +18,7 @@ + (TFLSegmentationResult*)segmentationResultWithCResult: (TfLiteSegmentationResult*)cSegmentationResult { - if (cSegmentationResult == nil) + if (!cSegmentationResult) return nil; NSMutableArray* segmentations = [[NSMutableArray alloc] init]; @@ -28,27 +28,28 @@ for (int j = 0; j < cSegmentation.colored_labels_size; j++) { TfLiteColoredLabel cColoredLabel = cSegmentation.colored_labels[j]; - TFLColoredLabel* coloredLabel = [[TFLColoredLabel alloc] init]; - coloredLabel.r = (NSUInteger)cColoredLabel.r; - coloredLabel.g = (NSUInteger)cColoredLabel.g; - coloredLabel.b = (NSUInteger)cColoredLabel.b; - - if (cColoredLabel.display_name != nil) { - coloredLabel.displayName = - [NSString stringWithCString:cColoredLabel.display_name - encoding:NSUTF8StringEncoding]; + NSString* displayName; + if (cColoredLabel.display_name) { + displayName = [NSString stringWithCString:cColoredLabel.display_name + encoding:NSUTF8StringEncoding]; } - if (cColoredLabel.label != nil) { - coloredLabel.label = [NSString stringWithCString:cColoredLabel.label - encoding:NSUTF8StringEncoding]; + NSString* label; + if (cColoredLabel.label) { + label = [NSString stringWithCString:cColoredLabel.label + encoding:NSUTF8StringEncoding]; } + TFLColoredLabel* coloredLabel = + [[TFLColoredLabel alloc] initWithRed:(NSUInteger)cColoredLabel.r + green:(NSUInteger)cColoredLabel.g + blue:(NSUInteger)cColoredLabel.b + label:label + displayName:displayName]; [coloredLabels addObject:coloredLabel]; } - TFLSegmentation* segmentation = [[TFLSegmentation alloc] init]; - segmentation.coloredLabels = coloredLabels; + TFLSegmentation* segmentation; if (cSegmentation.confidence_masks) { NSMutableArray* confidenceMasks = [[NSMutableArray alloc] init]; @@ -59,21 +60,24 @@ mask:cSegmentation.confidence_masks[i]]; [confidenceMasks addObject:confidenceMask]; } - segmentation.confidenceMasks = confidenceMasks; + segmentation = + [[TFLSegmentation alloc] initWithConfidenceMasks:confidenceMasks + 
coloredLabels:coloredLabels]; } else if (cSegmentation.category_mask) { - segmentation.categoryMask = + TFLCategoryMask* categoryMask = [[TFLCategoryMask alloc] initWithWidth:(NSInteger)cSegmentation.width height:(NSInteger)cSegmentation.height mask:cSegmentation.category_mask]; + segmentation = + [[TFLSegmentation alloc] initWithCategoryMask:categoryMask + coloredLabels:coloredLabels]; } [segmentations addObject:segmentation]; } - TFLSegmentationResult* segmentationResult = - [[TFLSegmentationResult alloc] init]; - segmentationResult.segmentations = segmentations; - return segmentationResult; + return [[TFLSegmentationResult alloc] initWithSegmentations:segmentations]; } + @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h index 49abd3f2..3aca456 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h
@@ -17,6 +17,7 @@ NS_ASSUME_NONNULL_BEGIN /** Holds a confidence mask belonging to a single class and its meta data. */ +NS_SWIFT_NAME(ConfidenceMask) @interface TFLConfidenceMask : NSObject /** @@ -31,21 +32,26 @@ @property(nonatomic, readonly) NSInteger width; /** - * The height of the mask. This is an intrinsic parameter of the model being + * The height of the mask. This is an intrinsic parameter of the model being * used, and does not depend on the input image dimensions. */ @property(nonatomic, readonly) NSInteger height; /** - * Initializes a confidence mask. + * Initializes a confidence mask. */ - (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height mask:(float* _Nullable)mask; +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + @end /** Holds category mask and its metadata. */ +NS_SWIFT_NAME(CategoryMask) @interface TFLCategoryMask : NSObject /** @@ -62,60 +68,100 @@ @property(nonatomic, readonly) NSInteger width; /** - * The height of the mask. This is an intrinsic parameter of the model being + * The height of the mask. This is an intrinsic parameter of the model being * used, and does not depend on the input image dimensions. */ @property(nonatomic, readonly) NSInteger height; ++ (instancetype)new NS_UNAVAILABLE; + /** - * Initializes a category mask. + * Initializes a new `TFLCategoryMask` mask. + * + * @param width Width of the mask. + * @param height Height of the mask. + * @param mask Flattened 2D-array of size `width` x `height`, in row major + * order. The value of each pixel in this mask represents the class to which the + * pixel belongs. + * + * @return An instance of TFLCategoryMask initialized to the specified values. */ - (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height mask:(UInt8* _Nullable)mask; +- (instancetype)init NS_UNAVAILABLE; + @end /** Holds a label associated with an RGB color, for display purposes. 
*/ +NS_SWIFT_NAME(ColoredLabel) @interface TFLColoredLabel : NSObject /** The RGB color components for the label, in the [0, 255] range. */ -@property(nonatomic, assign) NSUInteger r; -@property(nonatomic, assign) NSUInteger g; -@property(nonatomic, assign) NSUInteger b; +@property(nonatomic, readonly) NSUInteger r; +@property(nonatomic, readonly) NSUInteger g; +@property(nonatomic, readonly) NSUInteger b; -/** The class name, as provided in the label map packed in the TFLite Model +/** + * The class name, as provided in the label map packed in the TFLite Model * Metadata. */ -@property(nonatomic, copy) NSString* label; +@property(nonatomic, readonly) NSString* label; -/** The display name, as provided in the label map (if available) packed in - * the TFLite Model Metadata. See `display_names_locale` field in - * ImageSegmenterOptions. +/** + * The display name, as provided in the label map (if available) packed in + * the TFLite Model Metadata. See displayNamesLocale in + * TFLClassificationOptions. */ -@property(nonatomic, copy) NSString* displayName; +@property(nonatomic, readonly) NSString* displayName; + +/** + * Initializes a new `TFLColoredLabel` with red, gree, blue color components, + * label and display name. + * + * @param r Red component of the RGB color components. + * @param g Green component of the RGB color components. + * @param b Blue component of the RGB color components. + * @param label Class name. + * @param displayName Display name. + * + * @return An instance of TFLColoredLabel initialized with red, gree, blue color + * components, label and display name. + */ +- (instancetype)initWithRed:(NSUInteger)r + green:(NSUInteger)g + blue:(NSUInteger)b + label:(NSString*)label + displayName:(NSString*)displayName; + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; @end /** Encapsulates a resulting segmentation mask and associated metadata. 
*/ +NS_SWIFT_NAME(Segmentation) @interface TFLSegmentation : NSObject /** - * Array of confidence masks where each element is a confidence mask of size + * Array of confidence masks where each element is a confidence mask of size * `width` x `height`, one for each of the supported classes. * The value of each pixel in these masks represents the confidence score for * this particular class. * This property is mutually exclusive with `categoryMask`. */ -@property(nonatomic, strong, nullable) +@property(nonatomic, nullable, readonly) NSArray<TFLConfidenceMask*>* confidenceMasks; -/** Holds the category mask. +/** + * Holds the category mask. * The value of each pixel in this mask represents the class to which the * pixel belongs. * This property is mutually exclusive with `confidenceMasks`. */ -@property(nonatomic, strong, nullable) TFLCategoryMask* categoryMask; +@property(nonatomic, nullable, readonly) TFLCategoryMask* categoryMask; /** * The list of colored labels for all the supported categories (classes). @@ -124,11 +170,45 @@ * `colored_labels[i]`, `confidence_masks` indices, i.e. `confidence_masks[i]` * is associated with `colored_labels[i]`. */ -@property(nonatomic, strong) NSArray<TFLColoredLabel*>* coloredLabels; +@property(nonatomic, readonly) NSArray<TFLColoredLabel*>* coloredLabels; + ++ (instancetype)new NS_UNAVAILABLE; + +/** + * Initializes a new `TFLSegmentation` with an array of confidence masks and an + * array of colored labels. `categoryMask` is initialized to `nil` as it is + * mutually exclusive with `confidenceMasks`. + * + * @param confidenceMasks An array of `TFLConfidenceMask` objects. + * @param coloredLabels An array of `TFLColoredLabel` objects. + * + * @return An instance of `TFLSegmentation` initialized with an array of + * confidence masks and an array of colored labels. 
+ */ +- (instancetype) + initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels; + +/** + * Initializes a new `TFLSegmentation` with a category mask and array of colored + * labels. `confidenceMasks` is initialized to `nil` as it is mutually exclusive + * with `categoryMask`. + * + * @param categoryMask A `TFLCategoryMask` object. + * @param coloredLabels An array of `TFLColoredLabel` objects. + * + * @return An instance of `TFLSegmentation` initialized with a category mask and + * array of colored labels. + */ +- (instancetype)initWithCategoryMask:(TFLCategoryMask*)categoryMask + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels; + +- (instancetype)init NS_UNAVAILABLE; @end /** Encapsulates results of any image segmentation task. */ +NS_SWIFT_NAME(SegmentationResult) @interface TFLSegmentationResult : NSObject /** Array of segmentations returned after inference by model. @@ -137,7 +217,21 @@ * e.g. instance segmentation models, which may return one segmentation per * object. */ -@property(nonatomic, strong) NSArray<TFLSegmentation*>* segmentations; +@property(nonatomic, readonly) NSArray<TFLSegmentation*>* segmentations; + ++ (instancetype)new NS_UNAVAILABLE; + +/** + * Initializes a new `TFLSegmentationResult` with an array of segmentations. + * + * @param segmentations An array of `TFLSegmentation` objects. + * + * @return An instance of `TFLSegmentationResult` initialized with an array of + * segmentations. + */ +- (instancetype)initWithSegmentations:(NSArray<TFLSegmentation*>*)segmentations; + +- (instancetype)init NS_UNAVAILABLE; @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m index bf0e75c8..45b5510 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.m
@@ -13,12 +13,9 @@ limitations under the License. ==============================================================================*/ #import "tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult.h" +#import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h" -@implementation TFLCategoryMask { - NSInteger _width; - NSInteger _height; - UInt8* _mask; -} +@implementation TFLCategoryMask - (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height @@ -28,8 +25,11 @@ _width = width; _height = height; if (mask != NULL) { - _mask = malloc(width * height * sizeof(UInt8)); - memcpy(_mask, mask, width * height * sizeof(UInt8)); + _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(UInt8) + error:nil]; + if (_mask) { + memcpy(_mask, mask, width * height * sizeof(UInt8)); + } } } return self; @@ -47,11 +47,7 @@ @end -@implementation TFLConfidenceMask { - NSInteger _width; - NSInteger _height; - float* _mask; -} +@implementation TFLConfidenceMask - (instancetype)initWithWidth:(NSInteger)width height:(NSInteger)height @@ -61,8 +57,11 @@ _width = width; _height = height; if (mask != NULL) { - _mask = malloc(width * height * sizeof(float)); - memcpy(_mask, mask, width * height * sizeof(float)); + _mask = [TFLCommonUtils mallocWithSize:width * height * sizeof(float) + error:nil]; + if (_mask) { + memcpy(_mask, mask, width * height * sizeof(float)); + } } } return self; @@ -81,22 +80,66 @@ @end @implementation TFLColoredLabel -@synthesize r; -@synthesize g; -@synthesize b; -@synthesize label; -@synthesize displayName; + +- (instancetype)initWithRed:(NSUInteger)r + green:(NSUInteger)g + blue:(NSUInteger)b + label:(NSString*)label + displayName:(NSString*)displayName { + self = [super init]; + if (self) { + _r = r; + _g = g; + _b = b; + _label = label; + _displayName = displayName; + } + return self; +} @end @implementation TFLSegmentation -@synthesize confidenceMasks; -@synthesize categoryMask; -@synthesize coloredLabels; + +- (instancetype) 
+ initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels { + return [self initWithConfidenceMasks:confidenceMasks + categoryMask:nil + coloredLabels:coloredLabels]; +} + +- (instancetype)initWithCategoryMask:(TFLCategoryMask*)categoryMask + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels { + return [self initWithConfidenceMasks:nil + categoryMask:categoryMask + coloredLabels:coloredLabels]; +} + +- (instancetype) + initWithConfidenceMasks:(NSArray<TFLConfidenceMask*>*)confidenceMasks + categoryMask:(TFLCategoryMask*)categoryMask + coloredLabels:(NSArray<TFLColoredLabel*>*)coloredLabels { + self = [super init]; + if (self) { + _confidenceMasks = confidenceMasks; + _categoryMask = categoryMask; + _coloredLabels = coloredLabels; + } + return self; +} @end @implementation TFLSegmentationResult -@synthesize segmentations; +- (instancetype)initWithSegmentations: + (NSArray<TFLSegmentation*>*)segmentations { + self = [super init]; + if (self) { + _segmentations = segmentations; + } + + return self; +} @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m index e9d3b3d..8af88a7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m
@@ -67,7 +67,7 @@ [ret setValue:[NSNumber numberWithDouble:cCategory.score] forKey:[NSString stringWithUTF8String:cCategory.text]]; } - NLClassifierCategoriesDelete(cCategories); + TfLiteNLClassifierCategoriesDelete(cCategories); return ret; } @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m index 8d21a111..a68b7ed 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m
@@ -70,7 +70,7 @@ [ret setValue:[NSNumber numberWithDouble:cCategory.score] forKey:[NSString stringWithUTF8String:cCategory.text]]; } - NLClassifierCategoriesDelete(cCategories); + TfLiteNLClassifierCategoriesDelete(cCategories); return ret; } @end
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/BUILD index 36d97c6..0a756dc5 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/BUILD
@@ -16,10 +16,14 @@ deps = [ "//tensorflow_lite_support/c/task/vision:image_classifier", "//tensorflow_lite_support/ios:TFLCommonUtils", + "//tensorflow_lite_support/ios/task/core:TFLBaseOptions", "//tensorflow_lite_support/ios/task/core:TFLBaseOptionsHelpers", + "//tensorflow_lite_support/ios/task/processor:TFLClassificationOptions", "//tensorflow_lite_support/ios/task/processor:TFLClassificationOptionsHelpers", + "//tensorflow_lite_support/ios/task/processor:TFLClassificationResult", "//tensorflow_lite_support/ios/task/processor:TFLClassificationResultHelpers", "//tensorflow_lite_support/ios/task/vision/utils:GMLImageUtils", + "//tensorflow_lite_support/odml/ios/image:MLImage", ], ) @@ -36,10 +40,14 @@ deps = [ "//tensorflow_lite_support/c/task/vision:object_detector", "//tensorflow_lite_support/ios:TFLCommonUtils", + "//tensorflow_lite_support/ios/task/core:TFLBaseOptions", "//tensorflow_lite_support/ios/task/core:TFLBaseOptionsHelpers", + "//tensorflow_lite_support/ios/task/processor:TFLClassificationOptions", "//tensorflow_lite_support/ios/task/processor:TFLClassificationOptionsHelpers", + "//tensorflow_lite_support/ios/task/processor:TFLDetectionResult", "//tensorflow_lite_support/ios/task/processor:TFLDetectionResultHelpers", "//tensorflow_lite_support/ios/task/vision/utils:GMLImageUtils", + "//tensorflow_lite_support/odml/ios/image:MLImage", ], ) @@ -56,8 +64,11 @@ deps = [ "//tensorflow_lite_support/c/task/vision:image_segmenter", "//tensorflow_lite_support/ios:TFLCommonUtils", + "//tensorflow_lite_support/ios/task/core:TFLBaseOptions", "//tensorflow_lite_support/ios/task/core:TFLBaseOptionsHelpers", + "//tensorflow_lite_support/ios/task/processor:TFLSegmentationResult", "//tensorflow_lite_support/ios/task/processor:TFLSegmentationResultHelpers", "//tensorflow_lite_support/ios/task/vision/utils:GMLImageUtils", + "//tensorflow_lite_support/odml/ios/image:MLImage", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h index 5befd57..7e38abe 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h
@@ -24,103 +24,116 @@ /** * Options to configure TFLImageClassifier. */ +NS_SWIFT_NAME(ImageClassifierOptions) @interface TFLImageClassifierOptions : NSObject /** - * Base options that is used for creation of any type of task. - * @seealso TFLBaseOptions + * Base options that are used for creation of any type of task. + * @discussion Please see `TFLBaseOptions` for more details. */ @property(nonatomic, copy) TFLBaseOptions* baseOptions; /** * Options that configure the display and filtering of results. - * @seealso TFLClassificationOptions + * @discussion Please see `TFLClassificationOptions` for more details. */ @property(nonatomic, copy) TFLClassificationOptions* classificationOptions; /** - * Initializes TFLImageClassifierOptions with the model path set to the - * specified path to a model file. - * @description The external model file, must be a single standalone TFLite - * file. It could be packed with TFLite Model Metadata[1] and associated files - * if exist. Fail to provide the necessary metadata and associated files might + * Initializes a new `TFLImageClassifierOptions` with the absolute path to the + * model file stored locally on the device, set to the given the model path. + * + * @discussion The external model file, must be a single standalone TFLite file. + * It could be packed with TFLite Model Metadata[1] and associated files if + * exist. Fail to provide the necessary metadata and associated files might * result in errors. Check the [documentation] * (https://www.tensorflow.org/lite/convert/metadata) for each task about the * specific requirement. * - * @param modelPath Path to a TFLite model file. - * @return An instance of TFLImageClassifierOptions set to the specified - * modelPath. + * @param modelPath An absolute path to a TensorFlow Lite model file stored + * locally on the device. + * + * @return An instance of `TFLImageClassifierOptions` initialized to the given + * model path. 
*/ -- (nullable instancetype)initWithModelPath:(NSString*)modelPath; - -- (instancetype)init NS_UNAVAILABLE; - -+ (instancetype)new NS_UNAVAILABLE; +- (instancetype)initWithModelPath:(NSString*)modelPath; @end /** * A TensorFlow Lite Task Image Classifiier. */ +NS_SWIFT_NAME(ImageClassifier) @interface TFLImageClassifier : NSObject /** - * Creates TFLImageClassifier from a model file and specified options . + * Creates a new instance of `TFLImageClassifier` from the given + * `TFLImageClassifierOptions`. * - * @param options TFLImageClassifierOptions instance with the necessary - * properties set. + * @param options The options to use for configuring the `TFLImageClassifier`. + * @param error An optional error parameter populated when there is an error in + * initializing the image classifier. * - * @return A TFLImageClassifier instance. + * @return A new instance of `TFLImageClassifier` with the given options. `nil` + * if there is an error in initializing the image classifier. */ + (nullable instancetype)imageClassifierWithOptions: (TFLImageClassifierOptions*)options error:(NSError**)error - NS_SWIFT_NAME(imageClassifier(options:)); + NS_SWIFT_NAME(classifier(options:)); + ++ (instancetype)new NS_UNAVAILABLE; /** - * Performs classification on a GMLImage input, returns an array of - * categorization results where each member in the array is an array of - * TFLClass objects for each classification head. - * This method currently supports inference on only following type of images: - * 1. RGB and RGBA images for GMLImageSourceTypeImage. - * 2. kCVPixelFormatType_32RGBA, kCVPixelFormatType_32BGRA, - * kCVPixelFormatType_24RGB for GMLImageSourceTypePixelBuffer and - * GMLImageSourceTypeSampleBuffer. If you are using AVCaptureSession to setup - * camera and get the frames for inference, you must request for one of these - * supported formats from AVCaptureVideoDataOutput. 
For a sample code - * snippet, please refer to: - * https://github.com/tensorflow/examples/blob/master/lite/examples/image_classification/ios/ImageClassification/Camera%20Feed/CameraFeedManager.swift#L253 + * Performs classification on the given GMLImage. * - * @param image input to the model. - * @return An NSArray<NSArray<TFLClass *>*> * of classification results. + * @discussion This method currently supports classification of only the + * following types of images: + * 1. RGB and RGBA images for `GMLImageSourceTypeImage`. + * 2. kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to + * setup camera and get the frames for inference, you must request for this + * format from AVCaptureVideoDataOutput. Otherwise your classification results + * will be wrong. + * + * @param image An image to be classified, represented as a `GMLImage`. + * + * @return A TFLClassificationResult with one set of results per image + * classifier head. `nil` if there is an error encountered during + * classification. Please see `TFLClassificationResult` for more details. */ - (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image - error:(NSError* _Nullable*) - error - NS_SWIFT_NAME(classify(gmlImage:)); + error:(NSError**)error + NS_SWIFT_NAME(classify(mlImage:)); /** - * Performs classification on a GMLImage input on the pixels in the - * specified bounding box, returns an array of categorization results - * where each member in the array is an array of TFLClass objects for - * each classification head. + * Performs classification on the pixels within the specified region of interest + * of the given `GMLImage`. * - * @param image input to the model. - * @param roi CGRect specifying region of interest in image. + * @discussion This method currently supports inference on only following type + * of images: + * 1. RGB and RGBA images for `GMLImageSourceTypeImage`. + * 2. 
kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to + * setup camera and get the frames for inference, you must request for this + * format from AVCaptureVideoDataOutput. Otherwise your classification results + * will be wrong. * - * @return An NSArray<NSArray<TFLClass *>*> * of classification results. + * @param image An image to be classified, represented as a `GMLImage`. + * @param roi A CGRect specifying the region of interest within the given + * `GMLImage`, on which classification should be performed. + * + * @return A TFLClassificationResult with one set of results per image + * classifier head. `nil` if there is an error encountered during + * classification. */ - (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image regionOfInterest:(CGRect)roi - error:(NSError* _Nullable*) - error - NS_SWIFT_NAME(classify(gmlImage:regionOfInterest:)); + error:(NSError**)error + NS_SWIFT_NAME(classify(mlImage:regionOfInterest:)); - (instancetype)init NS_UNAVAILABLE; -+ (instancetype)new NS_UNAVAILABLE; - @end NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m index 9259a8d..79ad474 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.m
@@ -40,7 +40,7 @@ return self; } -- (nullable instancetype)initWithModelPath:(NSString*)modelPath { +- (instancetype)initWithModelPath:(NSString*)modelPath { self = [self init]; if (self) { self.baseOptions.modelFile.filePath = modelPath; @@ -66,40 +66,59 @@ + (nullable instancetype)imageClassifierWithOptions: (TFLImageClassifierOptions*)options error:(NSError**)error { - TfLiteImageClassifierOptions cOptions = TfLiteImageClassifierOptionsCreate(); - if (![options.classificationOptions - copyToCOptions:&(cOptions.classification_options) - error:error]) - return nil; - - [options.baseOptions copyToCOptions:&(cOptions.base_options)]; - - TfLiteSupportError *createClassifierError = nil; - TfLiteImageClassifier *imageClassifier = - TfLiteImageClassifierFromOptions(&cOptions, &createClassifierError); - - [options.classificationOptions - deleteCStringArraysOfClassificationOptions:&(cOptions.classification_options)]; - - if (!imageClassifier || ![TFLCommonUtils checkCError:createClassifierError - toError:error]) { - TfLiteSupportErrorDelete(createClassifierError); + if (!options) { + [TFLCommonUtils + createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError + description:@"TFLImageClassifierOptions argument cannot be nil."]; return nil; } - return [[TFLImageClassifier alloc] initWithImageClassifier:imageClassifier]; + TfLiteImageClassifierOptions cOptions = TfLiteImageClassifierOptionsCreate(); + + if (![options.classificationOptions + copyToCOptions:&(cOptions.classification_options) + error:error]) { + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions: + &(cOptions.classification_options)]; + return nil; + } + + [options.baseOptions copyToCOptions:&(cOptions.base_options)]; + + TfLiteSupportError* cCreateClassifierError = NULL; + TfLiteImageClassifier* cImageClassifier = + TfLiteImageClassifierFromOptions(&cOptions, &cCreateClassifierError); + + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions: + 
&(cOptions.classification_options)]; + + // Populate iOS error if TfliteSupportError is not null and afterwards delete + // it. + if (![TFLCommonUtils checkCError:cCreateClassifierError toError:error]) { + TfLiteSupportErrorDelete(cCreateClassifierError); + } + + // Return nil if classifier evaluates to nil. If an error was generted by the + // C layer, it has already been populated to an NSError and deleted before + // returning from the method. + if (!cImageClassifier) { + return nil; + } + + return [[TFLImageClassifier alloc] initWithImageClassifier:cImageClassifier]; } -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image - error:(NSError *_Nullable *)error { +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image + error:(NSError**)error { return [self classifyWithGMLImage:image regionOfInterest:CGRectMake(0, 0, image.width, image.height) error:error]; } -- (nullable TFLClassificationResult *)classifyWithGMLImage:(GMLImage *)image - regionOfInterest:(CGRect)roi - error:(NSError *_Nullable *)error { +- (nullable TFLClassificationResult*)classifyWithGMLImage:(GMLImage*)image + regionOfInterest:(CGRect)roi + error:(NSError**)error { if (!image) { [TFLCommonUtils createCustomError:error withCode:TFLSupportErrorCodeInvalidArgumentError @@ -118,19 +137,25 @@ .width = roi.size.width, .height = roi.size.height}; - TfLiteSupportError *classifyError = nil; + TfLiteSupportError* classifyError = NULL; TfLiteClassificationResult *cClassificationResult = TfLiteImageClassifierClassifyWithRoi( _imageClassifier, cFrameBuffer, &boundingBox, &classifyError); free(cFrameBuffer->buffer); - cFrameBuffer->buffer = nil; + cFrameBuffer->buffer = NULL; free(cFrameBuffer); - cFrameBuffer = nil; + cFrameBuffer = NULL; - if (!cClassificationResult || ![TFLCommonUtils checkCError:classifyError - toError:error]) { + // Populate iOS error if C Error is not null and afterwards delete it. 
+ if (![TFLCommonUtils checkCError:classifyError toError:error]) { TfLiteSupportErrorDelete(classifyError); + } + + // Return nil if C result evaluates to nil. If an error was generted by the C + // layer, it has already been populated to an NSError and deleted before + // returning from the method. + if (!cClassificationResult) { return nil; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h index eb0d5658..234e10d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h
@@ -20,103 +20,118 @@ NS_ASSUME_NONNULL_BEGIN /** - * Specifies the type of output segmentation mask to be returned as a result - * of the image segmentation operation. This allows specifying the type of - * post-processing to perform on the raw model results - * - * @seealso TfLiteSegmentationResult for more. + * Specifies the type of the output segmentation mask to be returned as the + * result of the image segmentation operation. This directs the + * `TFLImageSegmenter` to choose the type of post-processing to be performed on + * the raw model results. */ typedef NS_ENUM(NSUInteger, TFLOutputType) { /** Unspecified output type. */ - TFLUnspecifiedOutputType, + TFLOutputTypeUnspecified, /** * Gives a single output mask where each pixel represents the class which * the pixel in the original image was predicted to belong to. */ - TFLCategoryMaskOutputType, + TFLOutputTypeCategoryMask, /** * Gives a list of output masks where, for each mask, each pixel represents * the prediction confidence, usually in the [0, 1] range. */ - TFLConfidenceMasksOutputType, + TFLOutputTypeConfidenceMasks, -}; +} NS_SWIFT_NAME(OutputType); /** - * Options to configure TFLImageSegmenter. + * Options to configure `TFLImageSegmenter`. */ +NS_SWIFT_NAME(ImageSegmenterOptions) @interface TFLImageSegmenterOptions : NSObject /** * Base options that is used for creation of any type of task. - * @seealso TFLBaseOptions + * @discussion Please see `TFLBaseOptions` for more details. */ @property(nonatomic, copy) TFLBaseOptions* baseOptions; /** * Specifies the type of output segmentation mask to be returned as a result * of the image segmentation operation. 
- * @seealso TFLOutputType */ -@property(nonatomic, assign) TFLOutputType outputType; +@property(nonatomic) TFLOutputType outputType; -/** Display names local for display names*/ +/** + * The locale to use for display names. + */ @property(nonatomic, copy) NSString* displayNamesLocale; /** - * Initializes TFLImageSegmenterOptions with the model path set to the specified - * path to a model file. - * @description The external model file, must be a single standalone TFLite + * Initializes a new `TFLImageSegmenterOptions` with the absolute path to the + * model file stored locally on the device, set to the given model path. + * + * @discussion The external model file, must be a single standalone TFLite * file. It could be packed with TFLite Model Metadata[1] and associated files * if exist. Fail to provide the necessary metadata and associated files might * result in errors. Check the * [documentation](https://www.tensorflow.org/lite/convert/metadata) for each * task about the specific requirement. * - * @param modelPath Path to a TFLite model file. + * @param modelPath An absolute path to a TensorFlow Lite model file stored + * locally on the device. * - * @return An instance of TFLImageSegmenterOptions set to the specified - * modelPath. + * @return An instance of `TFLImageSegmenterOptions` initialized to the given + * model path. */ -- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath; +- (instancetype)initWithModelPath:(NSString*)modelPath; @end +NS_SWIFT_NAME(ImageSegmenter) @interface TFLImageSegmenter : NSObject /** - * Creates TFLImageSegmenter from a model file and specified options . + * Creates a new instance of `TFLImageSegmenter` from the given + * `TFLImageSegmenterOptions`. * - * @param options TFLImageSegmenterOptions instance with the necessary - * properties set. + * @param options The options to use for configuring the `TFLImageSegmenter`.
+ * @param error An optional error parameter populated when there is an error in + * initializing the image segmenter. * - * @return A TFLImageSegmenter instance. + * @return A new instance of `TFLImageSegmenter` with the given options. `nil` + * if there is an error in initializing the image segmenter. */ + (nullable instancetype)imageSegmenterWithOptions: (nonnull TFLImageSegmenterOptions*)options error:(NSError**)error - NS_SWIFT_NAME(imageSegmenter(options:)); - -/** - * Performs image segmentation on a GMLImage input, returns the segmentation - * results. - * - * @param image input to the model. - * - * @return Segmentation Result of type TFLSegmentationResult holds the - * segmentation masks returned by the image segmentation task. - */ -- (nullable TFLSegmentationResult*)segmentWithGMLImage:(GMLImage*)image - error: - (NSError* _Nullable*)error - NS_SWIFT_NAME(segment(gmlImage:)); - -- (instancetype)init NS_UNAVAILABLE; + NS_SWIFT_NAME(segmenter(options:)); + (instancetype)new NS_UNAVAILABLE; +/** + * Performs segmentation on the given GMLImage. + * + * @discussion This method currently supports segmentation of only the following + * types of images: + * 1. RGB and RGBA images for `GMLImageSourceTypeImage`. + * 2. kCVPixelFormatType_32BGRA for `GMLImageSourceTypePixelBuffer` and + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to + * setup camera and get the frames for inference, you must request for this + * format from AVCaptureVideoDataOutput. Otherwise your segmentation results + * will be wrong. + * + * @param image An image to be segmented, represented as a `GMLImage`. + * + * @return A TFLSegmentationResult that holds the segmentation masks returned by + * the image segmentation task. `nil` if there is an error encountered during + * segmentation. Please see `TFLSegmentationResult` for more details. 
+ */ +- (nullable TFLSegmentationResult*)segmentWithGMLImage:(GMLImage*)image + error:(NSError**)error + NS_SWIFT_NAME(segment(mlImage:)); + +- (instancetype)init NS_UNAVAILABLE; + @end NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m index 8f22beeb..7b7f3211 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.m
@@ -13,6 +13,7 @@ limitations under the License. ==============================================================================*/ #import "tensorflow_lite_support/ios/task/vision/sources/TFLImageSegmenter.h" +#import "tensorflow_lite_support/ios/sources/TFLCommon.h" #import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h" #import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h" #import "tensorflow_lite_support/ios/task/processor/sources/TFLSegmentationResult+Helpers.h" @@ -29,11 +30,12 @@ self = [super init]; if (self) { self.baseOptions = [[TFLBaseOptions alloc] init]; + self.outputType = TFLOutputTypeCategoryMask; } return self; } -- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath { +- (instancetype)initWithModelPath:(NSString*)modelPath { self = [self init]; if (self) { self.baseOptions.modelFile.filePath = modelPath; @@ -66,32 +68,62 @@ TfLiteImageSegmenterOptions cOptions = TfLiteImageSegmenterOptionsCreate(); [options.baseOptions copyToCOptions:&(cOptions.base_options)]; + cOptions.output_type = (TfLiteImageSegmenterOutputType)options.outputType; - TfLiteSupportError* createImageSegmenterError = nil; - TfLiteImageSegmenter* imageSegmenter = - TfLiteImageSegmenterFromOptions(&cOptions, &createImageSegmenterError); - - if (!imageSegmenter || ![TFLCommonUtils checkCError:createImageSegmenterError - toError:error]) { - TfLiteSupportErrorDelete(createImageSegmenterError); - return nil; + if (options.displayNamesLocale) { + if (options.displayNamesLocale.UTF8String) { + cOptions.display_names_locale = + strdup(options.displayNamesLocale.UTF8String); + if (!cOptions.display_names_locale) { + exit(-1); // Memory Allocation Failed. 
+ } + } else { + [TFLCommonUtils + createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError + description:@"Could not convert (NSString *) to (char *)."]; + return nil; + } } - return [[TFLImageSegmenter alloc] initWithImageSegmenter:imageSegmenter]; + TfLiteSupportError* cCreateImageSegmenterError = nil; + TfLiteImageSegmenter* cImageSegmenter = + TfLiteImageSegmenterFromOptions(&cOptions, &cCreateImageSegmenterError); + + // Freeing memory of allocated string. + free(cOptions.display_names_locale); + + if (![TFLCommonUtils checkCError:cCreateImageSegmenterError toError:error]) { + TfLiteSupportErrorDelete(cCreateImageSegmenterError); + } + + // Return nil if C image segmenter evaluates to nil. If an error was generated + // by the C layer, it has already been populated to an NSError and deleted + // before returning from the method. + if (!cImageSegmenter) { + return nil; + } + return [[TFLImageSegmenter alloc] initWithImageSegmenter:cImageSegmenter]; } - (nullable TFLSegmentationResult*)segmentWithGMLImage:(GMLImage*)image - error:(NSError* _Nullable*) - error { + error:(NSError**)error { + if (!image) { + [TFLCommonUtils createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError + description:@"GMLImage argument cannot be nil."]; + return nil; + } + TfLiteFrameBuffer* cFrameBuffer = [image cFrameBufferWithError:error]; if (!cFrameBuffer) { return nil; } - TfLiteSupportError* segmentError = nil; - TfLiteSegmentationResult* cSegmentationResult = - TfLiteImageSegmenterSegment(_imageSegmenter, cFrameBuffer, &segmentError); + TfLiteSupportError* cSegmentError = nil; + TfLiteSegmentationResult* cSegmentationResult = TfLiteImageSegmenterSegment( + _imageSegmenter, cFrameBuffer, &cSegmentError); free(cFrameBuffer->buffer); cFrameBuffer->buffer = nil; @@ -99,9 +131,15 @@ free(cFrameBuffer); cFrameBuffer = nil; - if (!cSegmentationResult || ![TFLCommonUtils checkCError:segmentError - toError:error]) { -
TfLiteSupportErrorDelete(segmentError); + // Populate iOS error if C Error is not null and afterwards delete it. + if (![TFLCommonUtils checkCError:cSegmentError toError:error]) { + TfLiteSupportErrorDelete(cSegmentError); + } + + // Return nil if C result evaluates to nil. If an error was generated by the C + // layer, it has already been populated to an NSError and deleted before + // returning from the method. + if (!cSegmentationResult) { return nil; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h index 38b3b0a..db76c90c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h
@@ -21,72 +21,88 @@ NS_ASSUME_NONNULL_BEGIN /** - * Options to configure TFLObjectDetector. + * Options to configure `TFLObjectDetector`. */ +NS_SWIFT_NAME(ObjectDetectorOptions) @interface TFLObjectDetectorOptions : NSObject /** * Base options that is used for creation of any type of task. - * @seealso TFLBaseOptions + * @discussion Please see `TFLBaseOptions` for more details. */ @property(nonatomic, copy) TFLBaseOptions* baseOptions; /** * Options that configure the display and filtering of results. - * @seealso TFLClassificationOptions + * @discussion Please see `TFLClassificationOptions` for more details. */ @property(nonatomic, copy) TFLClassificationOptions* classificationOptions; /** - * Initializes TFLObjectDetectorOptions with the model path set to the specified - * path to a model file. - * @description The external model file, must be a single standalone TFLite - * file. It could be packed with TFLite Model Metadata[1] and associated files - * if exist. Fail to provide the necessary metadata and associated files might + * Initializes a new `TFLObjectDetectorOptions` with the absolute path to the + * model file stored locally on the device, set to the given model path. + * + * @discussion The external model file, must be a single standalone TFLite file. + * It could be packed with TFLite Model Metadata[1] and associated files if + * exist. Fail to provide the necessary metadata and associated files might * result in errors. Check the [documentation] * (https://www.tensorflow.org/lite/convert/metadata) for each task about the * specific requirement. * - * @param modelPath Path to a TFLite model file. - * @return An instance of TFLObjectDetectorOptions set to the specified - * modelPath. + * @param modelPath An absolute path to a TensorFlow Lite model file stored + * locally on the device. + * @return An instance of `TFLObjectDetectorOptions` initialized to the given + * model path.
*/ -- (nullable instancetype)initWithModelPath:(nonnull NSString*)modelPath; +- (instancetype)initWithModelPath:(NSString*)modelPath; @end +NS_SWIFT_NAME(ObjectDetector) @interface TFLObjectDetector : NSObject /** - * Creates TFLObjectDetector from a model file and specified options . + * Creates a new instance of `TFLObjectDetector` from the given + * `TFLObjectDetectorOptions`. * - * @param options TFLObjectDetectorOptions instance with the necessary - * properties set. + * @param options The options to use for configuring the `TFLObjectDetector`. + * @param error An optional error parameter populated when there is an error in + * initializing the object detector. * - * @return A TFLObjectDetector instance. + * @return A new instance of `TFLObjectDetector` with the given options. `nil` + * if there is an error in initializing the object detector. */ + (nullable instancetype)objectDetectorWithOptions: - (nonnull TFLObjectDetectorOptions*)options + (TFLObjectDetectorOptions*)options error:(NSError**)error - NS_SWIFT_NAME(objectDetector(options:)); - -/** - * Performs object detection on a GMLImage input, returns the detected objects - * in the image. - * - * @param image input to the model. - * @return Detection Result of type TFLDetectionResult an array of - * detected objeects where each detected object has a bounding box and an array - * of TFLCategory holding the predicted classes for the detected object. - */ -- (nullable TFLDetectionResult*)detectWithGMLImage:(GMLImage*)image - error:(NSError* _Nullable*)error - NS_SWIFT_NAME(detect(gmlImage:)); - -- (instancetype)init NS_UNAVAILABLE; + NS_SWIFT_NAME(detector(options:)); + (instancetype)new NS_UNAVAILABLE; +/** + * Performs object detection on the given GMLImage. + * @discussion This method currently supports object detection on only the + * following types of images: + * 1. RGB and RGBA images for `GMLImageSourceTypeImage`. + * 2. 
`kCVPixelFormatType_32BGRA` for `GMLImageSourceTypePixelBuffer` and + * `GMLImageSourceTypeSampleBuffer`. If you are using `AVCaptureSession` to + * setup camera and get the frames for inference, you must request for this + * format from AVCaptureVideoDataOutput. Otherwise your object detection results + * will be wrong. + * + * @param image An image on which object detection is to be performed, + * represented as a `GMLImage`. + * + * @return A `TFLDetectionResult` holding an array of TFLDetection objects, each + * having a bounding box specifying the region they were detected in and an + * array of predicted classes. Please see `TFLDetectionResult` for more details. + */ +- (nullable TFLDetectionResult*)detectWithGMLImage:(GMLImage*)image + error:(NSError**)error + NS_SWIFT_NAME(detect(mlImage:)); + +- (instancetype)init NS_UNAVAILABLE; + @end NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m index 1e47e254..def2e5b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.m
@@ -13,6 +13,7 @@ limitations under the License. ==============================================================================*/ #import "tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h" +#import "tensorflow_lite_support/ios/sources/TFLCommon.h" #import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h" #import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+Helpers.h" #import "tensorflow_lite_support/ios/task/processor/sources/TFLClassificationOptions+Helpers.h" @@ -39,7 +40,7 @@ return self; } -- (nullable instancetype)initWithModelPath:(nonnull NSString *)modelPath { +- (instancetype)initWithModelPath:(NSString*)modelPath { self = [self init]; if (self) { self.baseOptions.modelFile.filePath = modelPath; @@ -62,43 +63,70 @@ return self; } -+ (nullable instancetype)objectDetectorWithOptions:(nonnull TFLObjectDetectorOptions *)options - error:(NSError **)error { - TfLiteObjectDetectorOptions cOptions = TfLiteObjectDetectorOptionsCreate(); - if (![options.classificationOptions - copyToCOptions:&(cOptions.classification_options) - error:error]) - return nil; - - [options.baseOptions copyToCOptions:&(cOptions.base_options)]; - - TfLiteSupportError *createObjectDetectorError = nil; - TfLiteObjectDetector *objectDetector = - TfLiteObjectDetectorFromOptions(&cOptions, &createObjectDetectorError); - - [options.classificationOptions - deleteCStringArraysOfClassificationOptions:&(cOptions.classification_options)]; - - if (!objectDetector || ![TFLCommonUtils checkCError:createObjectDetectorError - toError:error]) { - TfLiteSupportErrorDelete(createObjectDetectorError); ++ (nullable instancetype)objectDetectorWithOptions: + (TFLObjectDetectorOptions*)options + error:(NSError**)error { + if (!options) { + [TFLCommonUtils + createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError + description:@"TFLObjectDetectorOptions argument cannot be nil."]; return nil; } - return [[TFLObjectDetector alloc] 
initWithObjectDetector:objectDetector]; + TfLiteObjectDetectorOptions cOptions = TfLiteObjectDetectorOptionsCreate(); + if (![options.classificationOptions + copyToCOptions:&(cOptions.classification_options) + error:error]) { + // Deallocating any allocated memory on failure. + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions: + &(cOptions.classification_options)]; + return nil; + } + + [options.baseOptions copyToCOptions:&(cOptions.base_options)]; + + TfLiteSupportError* cCreateObjectDetectorError = nil; + TfLiteObjectDetector* cObjectDetector = + TfLiteObjectDetectorFromOptions(&cOptions, &cCreateObjectDetectorError); + + [options.classificationOptions deleteAllocatedMemoryOfClassificationOptions: + &(cOptions.classification_options)]; + + // Populate iOS error if TfLiteSupportError is not null and afterwards delete + // it. + if (![TFLCommonUtils checkCError:cCreateObjectDetectorError toError:error]) { + TfLiteSupportErrorDelete(cCreateObjectDetectorError); + } + + // Return nil if C object detector evaluates to nil. If an error was generated + // by the C layer, it has already been populated to an NSError and deleted + // before returning from the method.
+ if (!cObjectDetector) { + return nil; + } + + return [[TFLObjectDetector alloc] initWithObjectDetector:cObjectDetector]; } -- (nullable TFLDetectionResult *)detectWithGMLImage:(GMLImage *)image - error:(NSError *_Nullable *)error { +- (nullable TFLDetectionResult*)detectWithGMLImage:(GMLImage*)image + error:(NSError**)error { + if (!image) { + [TFLCommonUtils createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError + description:@"GMLImage argument cannot be nil."]; + return nil; + } + TfLiteFrameBuffer* cFrameBuffer = [image cFrameBufferWithError:error]; if (!cFrameBuffer) { return nil; } - TfLiteSupportError *detectError = nil; - TfLiteDetectionResult *cDetectionResult = - TfLiteObjectDetectorDetect(_objectDetector, cFrameBuffer, &detectError); + TfLiteSupportError* cDetectError = nil; + TfLiteDetectionResult* cDetectionResult = + TfLiteObjectDetectorDetect(_objectDetector, cFrameBuffer, &cDetectError); free(cFrameBuffer->buffer); cFrameBuffer->buffer = nil; @@ -106,9 +134,15 @@ free(cFrameBuffer); cFrameBuffer = nil; - if (!cDetectionResult || ![TFLCommonUtils checkCError:detectError - toError:error]) { - TfLiteSupportErrorDelete(detectError); + // Populate iOS error if C Error is not null and afterwards delete it. + if (![TFLCommonUtils checkCError:cDetectError toError:error]) { + TfLiteSupportErrorDelete(cDetectError); + } + + // Return nil if C result evaluates to nil. If an error was generated by the C + // layer, it has already been populated to an NSError and deleted before + // returning from the method. + if (!cDetectionResult) { return nil; }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m index 4339ade..532f75ef 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.m
@@ -32,13 +32,10 @@ buffer:(uint8_t*)buffer error:(NSError**)error; -+ (uint8_t* _Nullable) - convertBGRAtoRGBforPixelBufferBaseAddress:(CVPixelBufferRef)pixelBuffer - error:(NSError**)error; - + (TfLiteFrameBuffer*)cFramebufferFromCVPixelBuffer: (CVPixelBufferRef)pixelBuffer error:(NSError**)error; + @end @interface UIImage (RawPixelDataUtils) @@ -54,6 +51,10 @@ (enum TfLiteFrameBufferFormat)frameBufferFormat buffer:(uint8_t*)buffer error:(NSError**)error { + if (!buffer) { + return NULL; + } + TfLiteFrameBuffer* cFrameBuffer = [TFLCommonUtils mallocWithSize:sizeof(TfLiteFrameBuffer) error:error]; @@ -67,86 +68,14 @@ return cFrameBuffer; } -+ (TfLiteFrameBuffer*)cFramebufferFromCVPixelBuffer: - (CVPixelBufferRef)pixelBuffer - error:(NSError**)error { - uint8_t* buffer = NULL; - enum TfLiteFrameBufferFormat cPixelFormat = kRGB; - - CVPixelBufferLockBaseAddress(pixelBuffer, 0); - OSType pixelBufferFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); - - switch (pixelBufferFormat) { - case kCVPixelFormatType_24RGB: { - cPixelFormat = kRGB; - buffer = - [TFLCVPixelBufferUtils copyPixelBufferDataForInference:pixelBuffer - error:error]; - break; - } - case kCVPixelFormatType_32RGBA: { - cPixelFormat = kRGBA; - buffer = - [TFLCVPixelBufferUtils copyPixelBufferDataForInference:pixelBuffer - error:error]; - break; - } - case kCVPixelFormatType_32BGRA: { - cPixelFormat = kRGB; - buffer = [TFLCVPixelBufferUtils - convertBGRAtoRGBforPixelBufferBaseAddress:pixelBuffer - error:error]; - break; - } - default: { - [TFLCommonUtils - createCustomError:error - withCode:TFLSupportErrorCodeInvalidArgumentError - description: - @"Unsupported pixel format for CVPixelBuffer. 
Supported " - @"pixel format types are kCVPixelFormatType_32RGBA, " - @"kCVPixelFormatType_32BGRA, kCVPixelFormatType_24RGB"]; - } - } - - CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); - - if (!buffer) { - return NULL; - } - - return [self cFrameBufferWithWidth:(int)CVPixelBufferGetWidth(pixelBuffer) - height:(int)CVPixelBufferGetHeight(pixelBuffer) - frameBufferFormat:cPixelFormat - buffer:buffer - error:error]; -} - -+ (UInt8*)copyPixelBufferDataForInference:(CVPixelBufferRef)pixelBuffer - error:(NSError**)error { - size_t height = CVPixelBufferGetHeight(pixelBuffer); - size_t stride = CVPixelBufferGetBytesPerRow(pixelBuffer); - UInt8* buffer = [TFLCommonUtils mallocWithSize:sizeof(UInt8) * height * stride - error:error]; - - if (buffer) - memcpy(buffer, CVPixelBufferGetBaseAddress(pixelBuffer), height * stride); - - return buffer; -} - -+ (uint8_t*)convertBGRAtoRGBforPixelBufferBaseAddress: - (CVPixelBufferRef)pixelBuffer - error:(NSError**)error { - size_t width = CVPixelBufferGetWidth(pixelBuffer); - size_t height = CVPixelBufferGetHeight(pixelBuffer); - size_t stride = CVPixelBufferGetBytesPerRow(pixelBuffer); - - int destinationChannelCount = 3; - size_t destinationBytesPerRow = destinationChannelCount * width; - - uint8_t* pixelBufferBaseAddress = - (uint8_t*)CVPixelBufferGetBaseAddress(pixelBuffer); ++ (uint8_t*)createRGBImageDatafromImageData:(uint8_t*)data + withWidth:(size_t)width + height:(size_t)height + stride:(size_t)stride + pixelBufferFormat:(OSType)pixelBufferFormatType + error:(NSError**)error { + NSInteger destinationChannelCount = 3; + size_t destinationBytesPerRow = width * destinationChannelCount; uint8_t* destPixelBufferAddress = [TFLCommonUtils mallocWithSize:sizeof(uint8_t) * height * destinationBytesPerRow @@ -156,30 +85,107 @@ return NULL; } - vImage_Buffer srcBuffer = {.data = pixelBufferBaseAddress, - .height = height, - .width = width, + vImage_Buffer srcBuffer = {.data = data, + .height = (vImagePixelCount)height, + .width = 
(vImagePixelCount)width, .rowBytes = stride}; vImage_Buffer destBuffer = {.data = destPixelBufferAddress, - .height = height, - .width = width, + .height = (vImagePixelCount)height, + .width = (vImagePixelCount)width, .rowBytes = destinationBytesPerRow}; vImage_Error convertError = kvImageNoError; - convertError = - vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, kvImageNoFlags); + + switch (pixelBufferFormatType) { + case kCVPixelFormatType_32RGBA: { + convertError = vImageConvert_RGBA8888toRGB888(&srcBuffer, &destBuffer, + kvImageNoFlags); + break; + } + case kCVPixelFormatType_32BGRA: { + convertError = vImageConvert_BGRA8888toRGB888(&srcBuffer, &destBuffer, + kvImageNoFlags); + break; + } + default: { + [TFLCommonUtils + createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError + description: + @"Invalid source pixel buffer format. Expecting one of " + @"kCVPixelFormatType_32RGBA, kCVPixelFormatType_32BGRA, " + @"kCVPixelFormatType_32ARGB"]; + + free(destPixelBufferAddress); + return NULL; + } + } if (convertError != kvImageNoError) { [TFLCommonUtils createCustomError:error withCode:TFLSupportErrorCodeImageProcessingError description:@"Image format conversion failed."]; + + free(destPixelBufferAddress); return NULL; } return destPixelBufferAddress; } ++ (uint8_t*)createRGBImageDatafromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer + error:(NSError**)error { + CVPixelBufferLockBaseAddress(pixelBuffer, 0); + + uint8_t* rgbData = [TFLCVPixelBufferUtils + createRGBImageDatafromImageData:CVPixelBufferGetBaseAddress(pixelBuffer) + withWidth:CVPixelBufferGetWidth(pixelBuffer) + height:CVPixelBufferGetHeight(pixelBuffer) + stride:CVPixelBufferGetBytesPerRow(pixelBuffer) + pixelBufferFormat:CVPixelBufferGetPixelFormatType( + pixelBuffer) + error:error]; + + CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); + + return rgbData; +} + ++ (TfLiteFrameBuffer*)cFramebufferFromCVPixelBuffer: + (CVPixelBufferRef)pixelBuffer + error:(NSError**)error { + 
uint8_t* buffer = NULL; + enum TfLiteFrameBufferFormat cPixelFormat = kRGB; + + OSType pixelBufferFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); + + switch (pixelBufferFormat) { + case kCVPixelFormatType_32BGRA: { + cPixelFormat = kRGB; + + buffer = + [TFLCVPixelBufferUtils createRGBImageDatafromCVPixelBuffer:pixelBuffer + error:error]; + break; + } + default: { + [TFLCommonUtils + createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError + description: + @"Unsupported pixel format for CVPixelBuffer. Supported " + @"pixel format types are kCVPixelFormatType_32BGRA"]; + } + } + + return [self cFrameBufferWithWidth:(int)CVPixelBufferGetWidth(pixelBuffer) + height:(int)CVPixelBufferGetHeight(pixelBuffer) + frameBufferFormat:cPixelFormat + buffer:buffer + error:error]; +} + @end @implementation UIImage (RawPixelDataUtils) @@ -230,37 +236,48 @@ + (UInt8* _Nullable)pixelDataFromCGImage:(CGImageRef)cgImage error:(NSError**)error { - long width = CGImageGetWidth(cgImage); - long height = CGImageGetHeight(cgImage); + size_t width = CGImageGetWidth(cgImage); + size_t height = CGImageGetHeight(cgImage); NSInteger bitsPerComponent = 8; NSInteger channelCount = 4; UInt8* buffer_to_return = NULL; CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB(); + size_t bytesPerRow = channelCount * width; // iOS infers bytesPerRow if it is set to 0. // See // https://developer.apple.com/documentation/coregraphics/1455939-cgbitmapcontextcreate // But for segmentation test image, this was not the case. // Hence setting it to the value of channelCount*width. - CGContextRef context = CGBitmapContextCreate( - nil, width, height, bitsPerComponent, channelCount * width, colorSpace, - kCGImageAlphaNoneSkipLast | kCGBitmapByteOrder32Big); + // kCGImageAlphaNoneSkipLast specifies that Alpha will always be next to B. + // kCGBitmapByteOrder32Big specifies that R will be stored before B. + // In combination they signify a pixelFormat of kCVPixelFormatType_32RGBA.
+ CGBitmapInfo bitMapinfoFor32RGBA = + kCGImageAlphaNoneSkipLast | kCGBitmapByteOrder32Big; + CGContextRef context = + CGBitmapContextCreate(nil, width, height, bitsPerComponent, bytesPerRow, + colorSpace, bitMapinfoFor32RGBA); if (context) { CGContextDrawImage(context, CGRectMake(0, 0, width, height), cgImage); - buffer_to_return = [UIImage - populateRGBBufferFromSourceRGBABuffer:CGBitmapContextGetData(context) - width:width - height:height]; - CGContextRelease(context); - } + uint8_t* srcData = CGBitmapContextGetData(context); - if (buffer_to_return == NULL) { - [TFLCommonUtils createCustomError:error - withCode:TFLSupportErrorCodeImageProcessingError - description:@"Image format conversion failed."]; + if (srcData) { + // We have drawn the image as an RGBA image with 8 bitsPerComponent and + // hence can safely input a pixel format of type kCVPixelFormatType_32RGBA + // for conversion by vImage. + buffer_to_return = [TFLCVPixelBufferUtils + createRGBImageDatafromImageData:srcData + withWidth:width + height:height + stride:bytesPerRow + pixelBufferFormat:kCVPixelFormatType_32RGBA + error:error]; + } + + CGContextRelease(context); } CGColorSpaceRelease(colorSpace); @@ -268,41 +285,10 @@ return buffer_to_return; } -+ (nullable UInt8*)populateRGBBufferFromSourceRGBABuffer:(UInt8*)buffer - width:(size_t)width - height:(size_t)height { - if (!buffer) - return NULL; - - int sourceChannelCount = 4; - int destChannelCount = 3; - - UInt8* buffer_to_return = [TFLCommonUtils - mallocWithSize:sizeof(UInt8) * height * destChannelCount * width - error:nil]; - if (!buffer_to_return) { - return NULL; - } - for (int col = 0; col < width; col++) { - for (int row = 0; row < height; row++) { - long offset = sourceChannelCount * (row * width + col); - long rgbOffset = destChannelCount * (row * width + col); - buffer_to_return[rgbOffset] = buffer[offset]; - buffer_to_return[rgbOffset + 1] = buffer[offset + 1]; - buffer_to_return[rgbOffset + 2] = buffer[offset + 2]; - } - } - return 
buffer_to_return; -} - - (TfLiteFrameBuffer*)frameBufferFromCGImage:(CGImageRef)cgImage error:(NSError**)error { UInt8* buffer = [UIImage pixelDataFromCGImage:cgImage error:error]; - if (buffer == NULL) { - return NULL; - } - return [TFLCVPixelBufferUtils cFrameBufferWithWidth:(int)CGImageGetWidth(cgImage) height:(int)CGImageGetHeight(cgImage) @@ -317,13 +303,15 @@ int width = 0; int height = 0; + if (ciImage.pixelBuffer) { - buffer = [TFLCVPixelBufferUtils - convertBGRAtoRGBforPixelBufferBaseAddress:ciImage.pixelBuffer - error:error]; width = (int)CVPixelBufferGetWidth(ciImage.pixelBuffer); height = (int)CVPixelBufferGetHeight(ciImage.pixelBuffer); + buffer = [TFLCVPixelBufferUtils + createRGBImageDatafromCVPixelBuffer:ciImage.pixelBuffer + error:error]; + } else if (ciImage.CGImage) { buffer = [UIImage pixelDataFromCGImage:ciImage.CGImage error:error]; width = (int)CGImageGetWidth(ciImage.CGImage); @@ -336,10 +324,6 @@ @"CIImage should have CGImage or CVPixelBuffer info."]; } - if (buffer == NULL) { - return NULL; - } - return [TFLCVPixelBufferUtils cFrameBufferWithWidth:width height:height frameBufferFormat:kRGBA
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/BUILD new file mode 100644 index 0000000..53e8a54 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/BUILD
@@ -0,0 +1,30 @@ +load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") +load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") +load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") + +package( + default_visibility = ["//visibility:private"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLRingBufferObjcTestLibrary", + testonly = 1, + srcs = ["TFLRingBufferTests.m"], + tags = TFL_DEFAULT_TAGS, + deps = [ + "//tensorflow_lite_support/ios:TFLCommon", + "//tensorflow_lite_support/ios/task/audio/core:TFLRingBuffer", + "//third_party/apple_frameworks:XCTest", + ], +) + +ios_unit_test( + name = "TFLRingBufferObjcTest", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLRingBufferObjcTestLibrary", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m new file mode 100644 index 0000000..cd389b9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/audio/core/TFLRingBufferTests.m
@@ -0,0 +1,329 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import <XCTest/XCTest.h> + +#import "tensorflow_lite_support/ios/sources/TFLCommon.h" +#import "tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h" + +#define VerifyError(error, expectedErrorDomain, expectedErrorCode, \ + expectedLocalizedDescription) \ + XCTAssertEqual(error.domain, expectedErrorDomain); \ + XCTAssertEqual(error.code, expectedErrorCode); \ + XCTAssertEqualObjects(error.localizedDescription, \ + expectedLocalizedDescription); + +NS_ASSUME_NONNULL_BEGIN + +@interface TFLRingBufferTests : XCTestCase +@end + +@implementation TFLRingBufferTests + +- (void)testLoadSucceedsWithFullLengthBuffer { + NSInteger inDataLength = 5; + float inData[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + + TFLFloatBuffer* inBuffer = + [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataLength]; + + NSInteger bufferSize = 5; + TFLRingBuffer* ringBuffer = + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + + XCTAssertTrue([ringBuffer loadBuffer:inBuffer + offset:0 + size:inDataLength + error:nil]); + // State after load: [1.0, 2.0, 3.0, 4.0, 5.0] + + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + + float expectedData[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + + for (int i = 0; i 
< inDataLength; i++) { + XCTAssertEqual(outBuffer.data[i], inData[i]); + } +} + +- (void)testLoadSucceedsWithPartialLengthBuffer { + NSInteger inDataSize = 3; + float inData[] = {1.0f, 2.0f, 3.0f}; + TFLFloatBuffer* inBuffer = + [[TFLFloatBuffer alloc] initWithData:&(inData[0]) size:inDataSize]; + + NSInteger bufferSize = 5; + TFLRingBuffer* ringBuffer = + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + + XCTAssertTrue([ringBuffer loadBuffer:inBuffer + offset:0 + size:inDataSize + error:nil]); + + // State after load: [0.0, 0.0, 1.0, 2.0, 3.0] + + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + + // Expected state after loading most recent elements of source buffer. + float expectedData[] = {0.0f, 0.0f, 1.0f, 2.0f, 3.0f}; + + for (int i = 0; i < bufferSize; i++) { + XCTAssertEqual(outBuffer.data[i], expectedData[i]); + } +} + +- (void)testLoadSucceedsByShiftingOutOldElements { + NSInteger initialDataSize = 4; + float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f}; + + TFLFloatBuffer* initialBuffer = + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) + size:initialDataSize]; + + NSInteger bufferSize = 5; + TFLRingBuffer* ringBuffer = + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer + offset:0 + size:initialDataSize + error:nil]); + + // State after load: [0.0, 1.0, 2.0, 3.0, 4.0] + + NSInteger inDataSize = 3; + float inArray[] = {5, 6, 7}; + TFLFloatBuffer* inBuffer = + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:inDataSize]; + + XCTAssertTrue([ringBuffer loadBuffer:inBuffer + offset:0 + size:inDataSize + error:nil]); + + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + + // Expected state after loading most recent elements of source buffer. 
+ float expectedData[] = {3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + + for (int i = 0; i < bufferSize; i++) { + XCTAssertEqual(outBuffer.data[i], expectedData[i]); + } +} + +- (void)testLoadSucceedsWithMostRecentElements { + NSInteger initialDataSize = 5; + float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + + TFLFloatBuffer* initialBuffer = + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) + size:initialDataSize]; + + NSInteger bufferSize = 5; + TFLRingBuffer* ringBuffer = + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer + offset:0 + size:initialDataSize + error:nil]); + + // State after load: [1.0, 2.0, 3.0, 4.0, 5.0] + + NSInteger sourceDataSize = 6; + float sourceArray[] = {6, 7, 8, 9, 10, 11}; + TFLFloatBuffer* sourceBuffer = + [[TFLFloatBuffer alloc] initWithData:&(sourceArray[0]) + size:sourceDataSize]; + + XCTAssertTrue([ringBuffer loadBuffer:sourceBuffer + offset:0 + size:sourceDataSize + error:nil]); + + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + + // Expected state after loading most recent elements of source buffer. 
+ float expectedData[] = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; + + for (int i = 0; i < bufferSize; i++) { + XCTAssertEqual(outBuffer.data[i], expectedData[i]); + } +} + +- (void)testLoadSucceedsWithOffseAndMostRecentElements { + NSInteger initialDataSize = 5; + float initialArray[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + + TFLFloatBuffer* initialBuffer = + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) + size:initialDataSize]; + + NSInteger bufferSize = 5; + TFLRingBuffer* ringBuffer = + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer + offset:0 + size:initialDataSize + error:nil]); + + // State after load: [1.0, 2.0, 3.0, 4.0, 5.0] + + NSInteger totalInSize = 8; + float inArray[] = {6, 7, 8, 9, 10, 11, 12, 13}; + TFLFloatBuffer* inBuffer = + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize]; + + NSInteger offset = 2; + NSInteger inDataSize = 6; + XCTAssertTrue([ringBuffer loadBuffer:inBuffer + offset:offset + size:inDataSize + error:nil]); + + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + + // Expected state after load with most recent elements and offset. 
+ float expectedData[] = {9.0f, 10.0f, 11.0f, 12.0f, 13.0f}; + + for (int i = 0; i < bufferSize; i++) { + XCTAssertEqual(outBuffer.data[i], expectedData[i]); + } +} + +- (void)testLoadSucceedsWithOffset { + NSInteger initialDataSize = 2; + float initialArray[] = {1.0f, 2.0f}; + + TFLFloatBuffer* initialBuffer = + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) + size:initialDataSize]; + + NSInteger bufferSize = 5; + TFLRingBuffer* ringBuffer = + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer + offset:0 + size:initialDataSize + error:nil]); + + // State after load: [0.0, 0.0, 0.0, 1.0, 2.0] + + NSInteger totalInSize = 4; + float inArray[] = {6.0f, 7.0f, 8.0f, 9.0f}; + TFLFloatBuffer* inBuffer = + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize]; + + NSInteger offset = 2; + NSInteger inDataSize = 2; + XCTAssertTrue([ringBuffer loadBuffer:inBuffer + offset:offset + size:inDataSize + error:nil]); + + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + + // State after load with offset + float expectedData[] = {0.0f, 1.0f, 2.0f, 8.0f, 9.0f}; + + for (int i = 0; i < bufferSize; i++) { + XCTAssertEqual(outBuffer.data[i], expectedData[i]); + } +} + +- (void)testLoadFailsWithIndexOutofBounds { + NSInteger initialDataSize = 2; + float initialArray[] = {1.0f, 2.0f}; + + TFLFloatBuffer* initialBuffer = + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) + size:initialDataSize]; + + NSInteger bufferSize = 5; + TFLRingBuffer* ringBuffer = + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer + offset:0 + size:initialDataSize + error:nil]); + + NSInteger totalInSize = 4; + float inArray[] = {6.0f, 7.0f, 8.0f, 9.0f}; + TFLFloatBuffer* inBuffer = + [[TFLFloatBuffer alloc] initWithData:&(inArray[0]) size:totalInSize]; + + NSInteger offset = 2; + 
NSInteger inDataSize = 3; + + NSError* error = nil; + XCTAssertFalse([ringBuffer loadBuffer:inBuffer + offset:offset + size:inDataSize + error:&error]); + + XCTAssertNotNil(error); + VerifyError(error, @"org.tensorflow.lite.tasks", + TFLSupportErrorCodeInvalidArgumentError, + @"offset + size exceeds the maximum size of the source buffer."); +} + +- (void)testClearSucceeds { + NSInteger initialDataSize = 2; + float initialArray[] = {1.0f, 2.0f}; + + TFLFloatBuffer* initialBuffer = + [[TFLFloatBuffer alloc] initWithData:&(initialArray[0]) + size:initialDataSize]; + + NSInteger bufferSize = 5; + TFLRingBuffer* ringBuffer = + [[TFLRingBuffer alloc] initWithBufferSize:bufferSize]; + + XCTAssertTrue([ringBuffer loadBuffer:initialBuffer + offset:0 + size:initialDataSize + error:nil]); + + [ringBuffer clear]; + + float expectedData[] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + + TFLFloatBuffer* outBuffer = ringBuffer.floatBuffer; + XCTAssertNotNil(outBuffer); + XCTAssertEqual(outBuffer.size, bufferSize); + + for (int i = 0; i < bufferSize; i++) { + XCTAssertEqual(outBuffer.data[i], expectedData[i]); + } +} + +@end + +NS_ASSUME_NONNULL_END
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.swift b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.swift index 3bd26419..0380391 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.swift +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_classifier/TFLImageClassifierTests.swift
@@ -17,22 +17,21 @@ @testable import TFLImageClassifier -class TFLImageClassifierTests: XCTestCase { +class ImageClassifierTests: XCTestCase { - static let bundle = Bundle(for: TFLImageClassifierTests.self) + static let bundle = Bundle(for: ImageClassifierTests.self) static let modelPath = bundle.path( forResource: "mobilenet_v2_1.0_224", ofType: "tflite") func testSuccessfullInferenceOnMLImageWithUIImage() throws { - let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath) + let modelPath = try XCTUnwrap(ImageClassifierTests.modelPath) - let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath) - XCTAssertNotNil(imageClassifierOptions) + let imageClassifierOptions = ImageClassifierOptions(modelPath: modelPath) let imageClassifier = - try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!) + try ImageClassifier.classifier(options: imageClassifierOptions) let gmlImage = try XCTUnwrap( MLImage.imageFromBundle( @@ -40,8 +39,8 @@ filename: "burger", type: "jpg")) - let classificationResults: TFLClassificationResult = - try imageClassifier.classify(gmlImage: gmlImage) + let classificationResults: ClassificationResult = + try imageClassifier.classify(mlImage: gmlImage) XCTAssertNotNil(classificationResults) XCTAssertEqual(classificationResults.classifications.count, 1) @@ -54,16 +53,15 @@ func testModelOptionsWithMaxResults() throws { - let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath) + let modelPath = try XCTUnwrap(ImageClassifierTests.modelPath) - let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath) - XCTAssertNotNil(imageClassifierOptions) + let imageClassifierOptions = ImageClassifierOptions(modelPath: modelPath) let maxResults = 3 - imageClassifierOptions!.classificationOptions.maxResults = maxResults + imageClassifierOptions.classificationOptions.maxResults = maxResults let imageClassifier = - try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!) 
+ try ImageClassifier.classifier(options: imageClassifierOptions) let gmlImage = try XCTUnwrap( MLImage.imageFromBundle( @@ -71,8 +69,8 @@ filename: "burger", type: "jpg")) - let classificationResults: TFLClassificationResult = try imageClassifier.classify( - gmlImage: gmlImage) + let classificationResults: ClassificationResult = try imageClassifier.classify( + mlImage: gmlImage) XCTAssertNotNil(classificationResults) XCTAssertEqual(classificationResults.classifications.count, 1) @@ -86,13 +84,12 @@ func testInferenceWithBoundingBox() throws { - let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath) + let modelPath = try XCTUnwrap(ImageClassifierTests.modelPath) - let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath) - XCTAssertNotNil(imageClassifierOptions) + let imageClassifierOptions = ImageClassifierOptions(modelPath: modelPath) let imageClassifier = - try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!) + try ImageClassifier.classifier(options: imageClassifierOptions) let gmlImage = try XCTUnwrap( MLImage.imageFromBundle( @@ -102,7 +99,7 @@ let roi = CGRect(x: 406, y: 110, width: 148, height: 153) let classificationResults = - try imageClassifier.classify(gmlImage: gmlImage, regionOfInterest: roi) + try imageClassifier.classify(mlImage: gmlImage, regionOfInterest: roi) XCTAssertNotNil(classificationResults) XCTAssertEqual(classificationResults.classifications.count, 1) @@ -116,13 +113,12 @@ func testInferenceWithRGBAImage() throws { - let modelPath = try XCTUnwrap(TFLImageClassifierTests.modelPath) + let modelPath = try XCTUnwrap(ImageClassifierTests.modelPath) - let imageClassifierOptions = TFLImageClassifierOptions(modelPath: modelPath) - XCTAssertNotNil(imageClassifierOptions) + let imageClassifierOptions = ImageClassifierOptions(modelPath: modelPath) let imageClassifier = - try TFLImageClassifier.imageClassifier(options: imageClassifierOptions!) 
+ try ImageClassifier.classifier(options: imageClassifierOptions) let gmlImage = try XCTUnwrap( MLImage.imageFromBundle( @@ -131,7 +127,7 @@ type: "png")) let classificationResults = - try imageClassifier.classify(gmlImage: gmlImage) + try imageClassifier.classify(mlImage: gmlImage) XCTAssertNotNil(classificationResults) XCTAssertEqual(classificationResults.classifications.count, 1)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m index 39c3153..f483a51 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.m
@@ -52,8 +52,8 @@ // test method in the class. [super setUp]; self.modelPath = - [[NSBundle bundleForClass:[self class]] pathForResource:@"deeplabv3" - ofType:@"tflite"]; + [[NSBundle bundleForClass:self.class] pathForResource:@"deeplabv3" + ofType:@"tflite"]; XCTAssertNotNil(self.modelPath); } @@ -216,7 +216,7 @@ XCTAssertNotNil(imageSegmenter); GMLImage* gmlImage = - [GMLImage imageFromBundleWithClass:[self class] + [GMLImage imageFromBundleWithClass:self.class fileName:@"segmentation_input_rotation0" ofType:@"jpg"]; XCTAssertNotNil(gmlImage); @@ -236,7 +236,7 @@ XCTAssertTrue(segmentationResult.segmentations[0].categoryMask.mask != nil); GMLImage* goldenImage = - [GMLImage imageFromBundleWithClass:[self class] + [GMLImage imageFromBundleWithClass:self.class fileName:@"segmentation_golden_rotation0" ofType:@"png"];
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.swift b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.swift index 8e92a53..af476a691 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.swift +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/image_segmenter/TFLImageSegmenterTests.swift
@@ -17,12 +17,12 @@ @testable import TFLImageSegmenter -class TFLImageSegmenterTests: XCTestCase { +class ImageSegmenterTests: XCTestCase { - static let bundle = Bundle(for: TFLImageSegmenterTests.self) + static let bundle = Bundle(for: ImageSegmenterTests.self) static let modelPath = bundle.path( forResource: "deeplabv3", - ofType: "tflite")! + ofType: "tflite") // The maximum fraction of pixels in the candidate mask that can have a // different class than the golden mask for the test to pass. @@ -38,7 +38,7 @@ let deepLabV3SegmentationHeight = 257 - func verifyDeeplabV3PartialSegmentationResult(_ coloredLabels: [TFLColoredLabel]) { + func verifyDeeplabV3PartialSegmentationResult(_ coloredLabels: [ColoredLabel]) { self.verifyColoredLabel( coloredLabels[0], @@ -189,7 +189,7 @@ } func verifyColoredLabel( - _ coloredLabel: TFLColoredLabel, + _ coloredLabel: ColoredLabel, expectedR: UInt, expectedG: UInt, expectedB: UInt, @@ -211,20 +211,20 @@ func testSuccessfullInferenceOnMLImageWithUIImage() throws { - let modelPath = try XCTUnwrap(TFLImageSegmenterTests.modelPath) + let modelPath = try XCTUnwrap(ImageSegmenterTests.modelPath) - let imageSegmenterOptions = try XCTUnwrap(TFLImageSegmenterOptions(modelPath: modelPath)) + let imageSegmenterOptions = ImageSegmenterOptions(modelPath: modelPath) let imageSegmenter = - try TFLImageSegmenter.imageSegmenter(options: imageSegmenterOptions) + try ImageSegmenter.segmenter(options: imageSegmenterOptions) let gmlImage = try XCTUnwrap( MLImage.imageFromBundle( class: type(of: self), filename: "segmentation_input_rotation0", type: "jpg")) - let segmentationResult: TFLSegmentationResult = - try XCTUnwrap(imageSegmenter.segment(gmlImage: gmlImage)) + let segmentationResult: SegmentationResult = + try XCTUnwrap(imageSegmenter.segment(mlImage: gmlImage)) XCTAssertEqual(segmentationResult.segmentations.count, 1)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m index 7dffe7e1..f7091a5 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.m
@@ -18,16 +18,22 @@ #import "tensorflow_lite_support/ios/task/vision/sources/TFLObjectDetector.h" #import "tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h" -#define VerifyDetection(detection, expectedBoundingBox, expectedFirstScore, expectedFirstLabel) \ - XCTAssertGreaterThan([detection.categories count], 0); \ - NSLog(@"Detected %f", detection.categories[0].score); \ - NSLog(@"Expected %f", expectedFirstScore); \ - XCTAssertEqual(detection.boundingBox.origin.x, expectedBoundingBox.origin.x); \ - XCTAssertEqual(detection.boundingBox.origin.y, expectedBoundingBox.origin.y); \ - XCTAssertEqual(detection.boundingBox.size.width, expectedBoundingBox.size.width); \ - XCTAssertEqual(detection.boundingBox.size.height, expectedBoundingBox.size.height); \ - XCTAssertEqualObjects(detection.categories[0].label, expectedFirstLabel); \ - XCTAssertEqualWithAccuracy(detection.categories[0].score, expectedFirstScore, 0.001) +#define VerifyDetection(detection, expectedBoundingBox, expectedFirstScore, \ + expectedFirstLabel) \ + XCTAssertGreaterThan(detection.categories.count, 0); \ + NSLog(@"Detected %f", detection.categories[0].score); \ + NSLog(@"Expected %f", expectedFirstScore); \ + XCTAssertEqual(detection.boundingBox.origin.x, \ + expectedBoundingBox.origin.x); \ + XCTAssertEqual(detection.boundingBox.origin.y, \ + expectedBoundingBox.origin.y); \ + XCTAssertEqual(detection.boundingBox.size.width, \ + expectedBoundingBox.size.width); \ + XCTAssertEqual(detection.boundingBox.size.height, \ + expectedBoundingBox.size.height); \ + XCTAssertEqualObjects(detection.categories[0].label, expectedFirstLabel); \ + XCTAssertEqualWithAccuracy(detection.categories[0].score, \ + expectedFirstScore, 0.001) @interface TFLObjectDetectorTests : XCTestCase @property(nonatomic, nullable) NSString *modelPath; @@ -39,14 +45,14 @@ // Put setup code here. This method is called before the invocation of each test method in the // class. 
[super setUp]; - self.modelPath = [[NSBundle bundleForClass:[self class]] + self.modelPath = [[NSBundle bundleForClass:self.class] pathForResource:@"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29" ofType:@"tflite"]; XCTAssertNotNil(self.modelPath); } - (void)verifyResults:(TFLDetectionResult *)detectionResult { - XCTAssertGreaterThan([detectionResult.detections count], 0); + XCTAssertGreaterThan(detectionResult.detections.count, 0); VerifyDetection(detectionResult.detections[0], CGRectMake(54, 396, 393, 199), // expectedBoundingBox 0.632812, // expectedFirstScore @@ -77,7 +83,7 @@ [TFLObjectDetector objectDetectorWithOptions:objectDetectorOptions error:nil]; XCTAssertNotNil(objectDetector); - GMLImage *gmlImage = [GMLImage imageFromBundleWithClass:[self class] + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class fileName:@"cats_and_dogs" ofType:@"jpg"]; XCTAssertNotNil(gmlImage); @@ -96,14 +102,14 @@ [TFLObjectDetector objectDetectorWithOptions:objectDetectorOptions error:nil]; XCTAssertNotNil(objectDetector); - GMLImage *gmlImage = [GMLImage imageFromBundleWithClass:[self class] + GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:self.class fileName:@"cats_and_dogs" ofType:@"jpg"]; XCTAssertNotNil(gmlImage); TFLDetectionResult *detectionResult = [objectDetector detectWithGMLImage:gmlImage error:nil]; - XCTAssertLessThanOrEqual([detectionResult.detections count], maxResults); + XCTAssertLessThanOrEqual(detectionResult.detections.count, maxResults); VerifyDetection(detectionResult.detections[0], CGRectMake(54, 396, 393, 199), // expectedBoundingBox 0.632812, // expectedFirstScore
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.swift b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.swift index d676b9e..c1014e6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.swift +++ b/third_party/tflite_support/src/tensorflow_lite_support/ios/test/task/vision/object_detector/TFLObjectDetectorTests.swift
@@ -17,14 +17,14 @@ @testable import TFLObjectDetector -class TFLObjectDetectorTests: XCTestCase { +class ObjectDetectorTests: XCTestCase { - static let bundle = Bundle(for: TFLObjectDetectorTests.self) + static let bundle = Bundle(for: ObjectDetectorTests.self) static let modelPath = bundle.path( forResource: "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29", - ofType: "tflite")! + ofType: "tflite") - func verifyDetectionResult(_ detectionResult: TFLDetectionResult) { + func verifyDetectionResult(_ detectionResult: DetectionResult) { XCTAssertGreaterThan(detectionResult.detections.count, 0) self.verifyDetection( @@ -53,7 +53,7 @@ } func verifyDetection( - _ detection: TFLDetection, expectedBoundingBox: CGRect, + _ detection: Detection, expectedBoundingBox: CGRect, expectedFirstScore: Float, expectedFirstLabel: String ) { @@ -73,52 +73,50 @@ XCTAssertEqual( detection.categories[0].label, expectedFirstLabel) - XCTAssertEqualWithAccuracy( + XCTAssertEqual( detection.categories[0].score, expectedFirstScore, accuracy: 0.001) } func testSuccessfullInferenceOnMLImageWithUIImage() throws { - let modelPath = try XCTUnwrap(TFLObjectDetectorTests.modelPath) + let modelPath = try XCTUnwrap(ObjectDetectorTests.modelPath) - let objectDetectorOptions = TFLObjectDetectorOptions(modelPath: modelPath) - XCTAssertNotNil(objectDetectorOptions) + let objectDetectorOptions = ObjectDetectorOptions(modelPath: modelPath) let objectDetector = - try TFLObjectDetector.objectDetector(options: objectDetectorOptions!) 
+ try ObjectDetector.detector(options: objectDetectorOptions) let gmlImage = try XCTUnwrap( MLImage.imageFromBundle( class: type(of: self), filename: "cats_and_dogs", type: "jpg")) - let detectionResults: TFLDetectionResult = - try objectDetector.detect(gmlImage: gmlImage) + let detectionResults: DetectionResult = + try objectDetector.detect(mlImage: gmlImage) self.verifyDetectionResult(detectionResults) } func testModelOptionsWithMaxResults() throws { - let modelPath = try XCTUnwrap(TFLObjectDetectorTests.modelPath) + let modelPath = try XCTUnwrap(ObjectDetectorTests.modelPath) - let objectDetectorOptions = TFLObjectDetectorOptions(modelPath: modelPath) - XCTAssertNotNil(objectDetectorOptions) + let objectDetectorOptions = ObjectDetectorOptions(modelPath: modelPath) let maxResults = 3 - objectDetectorOptions!.classificationOptions.maxResults = maxResults + objectDetectorOptions.classificationOptions.maxResults = maxResults let objectDetector = - try TFLObjectDetector.objectDetector(options: objectDetectorOptions!) + try ObjectDetector.detector(options: objectDetectorOptions) let gmlImage = try XCTUnwrap( MLImage.imageFromBundle( class: type(of: self), filename: "cats_and_dogs", type: "jpg")) - let detectionResult: TFLDetectionResult = try objectDetector.detect( - gmlImage: gmlImage) + let detectionResult: DetectionResult = try objectDetector.detect( + mlImage: gmlImage) XCTAssertLessThanOrEqual(detectionResult.detections.count, maxResults)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java index b3eb11f..85c5d12e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java
@@ -24,7 +24,6 @@ import android.os.ParcelFileDescriptor; import org.tensorflow.lite.DataType; -import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.audio.TensorAudio; import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat; import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; @@ -33,6 +32,7 @@ import org.tensorflow.lite.task.core.TaskJniUtils; import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; +import org.tensorflow.lite.task.core.annotations.UsedByReflection; import java.io.File; import java.io.IOException;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java index 7d5b07f..8e82702 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java
@@ -17,8 +17,8 @@ import com.google.auto.value.AutoValue; -import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.core.annotations.UsedByReflection; import java.util.ArrayList; import java.util.Collections;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD index 4f3e538..9a15bbb7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD
@@ -7,12 +7,16 @@ android_library( name = "base-task-api", - srcs = glob(["**/*.java"]), + srcs = glob(["**/*.java"]) + [ + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor:task_processor_src", + ], javacopts = ["-source 7 -target 7"], + proguard_specs = ["proguard.flags"], visibility = ["//visibility:public"], # LINT.IfChange(dep) deps = [ "@com_google_auto_value", + "@maven//:androidx_annotation_annotation", ], # LINT.ThenChange(<INTERNAL>/release/build_task_base_pom.sh:dep) )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java new file mode 100644 index 0000000..fb1dfec8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/annotations/UsedByReflection.java
@@ -0,0 +1,31 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.core.annotations; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Target; + +/** + * Annotation used for marking methods and fields that are called by reflection. Useful for keeping + * components that would otherwise be removed by Proguard. Use the value parameter to mention a file + * that calls this method. + * + * @hide + */ +@Target({ElementType.METHOD, ElementType.FIELD, ElementType.TYPE, ElementType.CONSTRUCTOR}) +public @interface UsedByReflection { + String value(); +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/proguard.flags b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/proguard.flags new file mode 100644 index 0000000..7cd8254 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/proguard.flags
@@ -0,0 +1,5 @@ +-keep class org.tensorflow.lite.task.core.annotations.UsedByReflection +-keep @org.tensorflow.lite.task.core.annotations.UsedByReflection class * +-keepclassmembers class * { + @org.tensorflow.lite.task.core.annotations.UsedByReflection *; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/BUILD new file mode 100644 index 0000000..0021f5e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/BUILD
@@ -0,0 +1,9 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "task_processor_src", + srcs = glob(["**/*.java"]), +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java new file mode 100644 index 0000000..a39247f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/NearestNeighbor.java
@@ -0,0 +1,53 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.processor; + +import com.google.auto.value.AutoValue; + +import org.tensorflow.lite.task.core.annotations.UsedByReflection; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** Represents the search result of a Searcher model. */ +@AutoValue +@UsedByReflection("searcher_jni.cc") +public abstract class NearestNeighbor { + @UsedByReflection("searcher_jni.cc") + static NearestNeighbor create(byte[] metadataArray, float distance) { + // Convert byte[] metadataArray to a ByteBuffer, which handles endianness better. + // + // Ideally, the API should accept a ByteBuffer instead of a byte[]. However, converting + // byte[] to ByteBuffer in JNI will lead to unnecessarily complex code which involves 6 more + // reflection calls. This method is package-private because users in general + // shouldn't need to create NearestNeighbor instances, but only consume the objects returned + // from the Task Library. This API is mostly intended for internal use. + ByteBuffer metadata = ByteBuffer.wrap(metadataArray); + metadata.order(ByteOrder.nativeOrder()); + return new AutoValue_NearestNeighbor(metadata, distance); + } + + /** + * Gets the user-defined metadata about the result.
This could be a label, a unique ID, a + * serialized proto of some sort, etc. + * + * <p><b>Do not mutate</b> the returned metadata. + */ + public abstract ByteBuffer getMetadata(); + + /** Gets the distance score indicating how confident the result is. Lower is better. */ + public abstract float getDistance(); +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java new file mode 100644 index 0000000..86f5fdde --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/processor/SearcherOptions.java
@@ -0,0 +1,83 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.processor; + +import androidx.annotation.Nullable; + +import com.google.auto.value.AutoValue; + +import java.io.File; + +/** Options to configure Searcher API. */ +@AutoValue +public abstract class SearcherOptions { + private static final boolean DEFAULT_L2_NORMALIZE = false; + private static final boolean DEFAULT_QUANTIZE = false; + private static final int DEFAULT_MAX_RESULTS = 5; + + public abstract boolean getL2Normalize(); + + public abstract boolean getQuantize(); + + @Nullable + public abstract File getIndexFile(); + + public abstract int getMaxResults(); + + public static Builder builder() { + return new AutoValue_SearcherOptions.Builder() + .setL2Normalize(DEFAULT_L2_NORMALIZE) + .setQuantize(DEFAULT_QUANTIZE) + .setIndexFile(null) + .setMaxResults(DEFAULT_MAX_RESULTS); + } + + /** Builder for {@link SearcherOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** + * Sets whether to normalize the embedding feature vector with L2 norm. Defaults to false. + * + * <p>Use this option only if the model does not already contain a native L2_NORMALIZATION + * TFLite Op. In most cases, this is already the case and L2 norm is thus achieved through + * TFLite inference. 
+ */ + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the embedding should be quantized to bytes via scalar quantization. Defaults + * to false. + * + * <p>Embeddings are implicitly assumed to be unit-norm and therefore any dimension is + * guaranteed to have a value in {@code [-1.0, 1.0]}. Use {@link #setL2Normalize} if this is + * not the case. + */ + public abstract Builder setQuantize(boolean quantize); + + /** + * Sets the index file to search in. + * + * <p>Required if the model does not come with an index file inside. Otherwise, it can be + * left unset or set to {@code null} (the default). + */ + public abstract Builder setIndexFile(@Nullable File indexFile); + + /** Sets the maximum number of nearest neighbor results to return. Defaults to {@code 5}. */ + public abstract Builder setMaxResults(int maxResults); + + public abstract SearcherOptions build(); + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD index 2d9837e..c885d9e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD
@@ -13,6 +13,7 @@ srcs = [ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier:nl_classifier_src", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa:bert_question_answerer_src", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher:text_searcher_src", ], javacopts = ["-source 7 -target 7"], manifest = "AndroidManifest.xml",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java index ce912c9..070b945e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
@@ -20,12 +20,12 @@ import com.google.auto.value.AutoValue; -import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.label.Category; import org.tensorflow.lite.task.core.BaseOptions; import org.tensorflow.lite.task.core.BaseTaskApi; import org.tensorflow.lite.task.core.TaskJniUtils; import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +import org.tensorflow.lite.task.core.annotations.UsedByReflection; import java.io.File; import java.io.IOException;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java index b8aa32be..5c3eb2c9e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
@@ -22,12 +22,12 @@ import com.google.auto.value.AutoValue; -import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.label.Category; import org.tensorflow.lite.task.core.BaseOptions; import org.tensorflow.lite.task.core.BaseTaskApi; import org.tensorflow.lite.task.core.TaskJniUtils; import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +import org.tensorflow.lite.task.core.annotations.UsedByReflection; import java.io.File; import java.io.IOException;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java index 955da99..50917c03 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
@@ -15,7 +15,7 @@ package org.tensorflow.lite.task.text.qa; -import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.task.core.annotations.UsedByReflection; /** * Answers to {@link QuestionAnswerer}. Contains information about the answer and its relative
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/AndroidManifest.xml b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/AndroidManifest.xml new file mode 100644 index 0000000..f390318 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/AndroidManifest.xml
@@ -0,0 +1,5 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.lite.task.text.searcher"> + <uses-sdk android:minSdkVersion="21" android:targetSdkVersion="29"/> +</manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/BUILD new file mode 100644 index 0000000..e4f4981 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/BUILD
@@ -0,0 +1,53 @@ +load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "android_library_with_tflite") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "AndroidManifest.xml", +]) + +filegroup( + name = "text_searcher_src", + srcs = glob(["**/*.java"]), +) + +# Default target that uses built-in ops, plus the custom ops for Universal Sentence Encoder. +android_library_with_tflite( + name = "text_searcher", + tflite_exports = [ + "//tensorflow_lite_support/java/src/native/task/text/searcher:text_searcher_native", + ], + exports = [ + ":text_searcher_java", + ], +) + +# Java-only target that needs to be used together with a native target similar to +# //tensorflow_lite_support/java/src/native/task/text/searcher:text_searcher_native. +# Use this target when you want to provide a MutableOpResolver with customized +# OPs and/or a subset of BuiltInOps to reduce binary size. +android_library( + name = "text_searcher_java", + srcs = glob(["*.java"]), + javacopts = ["-source 7 -target 7"], + manifest = "AndroidManifest.xml", + deps = [ + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "@com_google_auto_value", + "@maven//:androidx_annotation_annotation", + ], +) + +# AAR target for OSS release. +# +# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher:text-searcher +aar_with_jni( + name = "text-searcher", + android_library = ":text_searcher", +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java new file mode 100644 index 0000000..ea3b1b8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/searcher/TextSearcher.java
@@ -0,0 +1,262 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.text.searcher; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.os.ParcelFileDescriptor; + +import com.google.auto.value.AutoValue; + +import org.tensorflow.lite.task.core.BaseOptions; +import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.TaskJniUtils; +import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +import org.tensorflow.lite.task.processor.NearestNeighbor; +import org.tensorflow.lite.task.processor.SearcherOptions; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.util.List; + +/** + * Performs similarity search on text string. + * + * <p>The API expects a TFLite model with optional, but strongly recommended, <a + * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>. + * + * <p>The API expects a TFLite model with metadata populated. The metadata should contain the + * following information: + * + * <ul> + * <li>For Bert based TFLite model: + * <ul> + * <li>3 input tensors of type kTfLiteString with names "ids", "mask" and "segment_ids". 
+ * <li>input_process_units for Wordpiece/Sentencepiece Tokenizer + * <li>exactly one output tensor of type kTfLiteFloat32 + * </ul> + * <li>For Regex based TFLite model: + * <ul> + * <li>1 input tensor. + * <li>input_process_units for RegexTokenizer + * <li>exactly one output tensor of type kTfLiteFloat32 + * </ul> + * <li>For Universal Sentence Encoder based TFLite model: + * <ul> + * <li>3 input tensors with names "inp_text", "res_context" and "res_text" + * <li>2 output tensors with names "query_encoding" and "response_encoding" of type + * kTfLiteFloat32 + * </ul> + * </ul> + * + * <p>TODO(b/180502532): add pointer to example model. + * + * <p>TODO(b/222671076): add factory create methods without options, such as {@code createFromFile}, once + * the single file format (index file packed in the model) is supported. + */ +public final class TextSearcher extends BaseTaskApi { + private static final String TEXT_SEARCHER_NATIVE_LIB = "task_text_jni"; + private static final int OPTIONAL_FD_LENGTH = -1; + private static final int OPTIONAL_FD_OFFSET = -1; + + /** + * Creates a {@link TextSearcher} instance from {@link TextSearcherOptions}.
+ * + * @param modelPath path of the search model with metadata in the assets + * @throws IOException if an I/O error occurs when loading the tflite model or the index file + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static TextSearcher createFromFileAndOptions(Context context, String modelPath, + final TextSearcherOptions options) throws IOException { + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) { + return createFromModelFdAndOptions( + /*modelDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(), + /*modelDescriptorLength=*/assetFileDescriptor.getLength(), + /*modelDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options); + } + } + + /** + * Creates an {@link TextSearcher} instance. + * + * @param modelFile the search model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model or the index file + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static TextSearcher createFromFileAndOptions( + File modelFile, final TextSearcherOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return createFromModelFdAndOptions( + /*modelDescriptor=*/descriptor.getFd(), + /*modelDescriptorLength=*/OPTIONAL_FD_LENGTH, + /*modelDescriptorOffset=*/OPTIONAL_FD_OFFSET, options); + } + } + + /** + * Creates an {@link TextSearcher} instance with a model buffer and {@link TextSearcherOptions}. 
+ * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search + * model + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + * @throws IOException if an I/O error occurs when loading the index file + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static TextSearcher createFromBufferAndOptions( + final ByteBuffer modelBuffer, final TextSearcherOptions options) throws IOException { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + if (options.getSearcherOptions().getIndexFile() != null) { + try (ParcelFileDescriptor indexDescriptor = + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(), + ParcelFileDescriptor.MODE_READ_ONLY)) { + return createFromBufferAndOptionsImpl( + modelBuffer, options, indexDescriptor.getFd()); + } + } else { + return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/0); + } + } + + public static TextSearcher createFromBufferAndOptionsImpl( + final ByteBuffer modelBuffer, final TextSearcherOptions options, final int indexFd) { + return new TextSearcher(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer(modelBuffer, + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), + options.getSearcherOptions().getL2Normalize(), + options.getSearcherOptions().getQuantize(), indexFd, + options.getSearcherOptions().getMaxResults()); + } + }, TEXT_SEARCHER_NATIVE_LIB)); + } + + /** + * Constructor to initialize the JNI with a pointer from C++. 
+ * + * @param nativeHandle a pointer referencing memory allocated in C++ + */ + TextSearcher(long nativeHandle) { + super(nativeHandle); + } + + /** Options for setting up a TextSearcher. */ + @AutoValue + public abstract static class TextSearcherOptions { + abstract BaseOptions getBaseOptions(); + + abstract SearcherOptions getSearcherOptions(); + + public static Builder builder() { + return new AutoValue_TextSearcher_TextSearcherOptions.Builder() + .setBaseOptions(BaseOptions.builder().build()) + .setSearcherOptions(SearcherOptions.builder().build()); + } + + /** Builder for {@link TextSearcherOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the general options to configure Task APIs, such as accelerators. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** Sets the options to configure the Searcher API. */ + public abstract Builder setSearcherOptions(SearcherOptions searcherOptions); + + public abstract TextSearcherOptions build(); + } + } + + /** + * Performs embedding extraction on the provided string input, followed by nearest-neighbor + * search in the index. + * + * @param text the input text query + */ + public List<NearestNeighbor> search(String text) { + return searchNative(getNativeHandle(), text); + } + + private static TextSearcher createFromModelFdAndOptions(final int modelDescriptor, + final long modelDescriptorLength, final long modelDescriptorOffset, + final TextSearcherOptions options) throws IOException { + if (options.getSearcherOptions().getIndexFile() != null) { + // indexDescriptor must stay open until TextSearcher has been completely initialized in the + // native layer.
+ try (ParcelFileDescriptor indexDescriptor = + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(), + ParcelFileDescriptor.MODE_READ_ONLY)) { + return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength, + modelDescriptorOffset, options, indexDescriptor.getFd()); + } + } else { + // Index file is not configured. We'll check if the model contains one in the native + // layer. + return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength, + modelDescriptorOffset, options, /*indexFd=*/0); + } + } + + private static TextSearcher createFromModelFdAndOptionsImpl(final int modelDescriptor, + final long modelDescriptorLength, final long modelDescriptorOffset, + final TextSearcherOptions options, final int indexFd) { + long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithModelFdAndOptions(modelDescriptor, modelDescriptorLength, + modelDescriptorOffset, + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), + options.getSearcherOptions().getL2Normalize(), + options.getSearcherOptions().getQuantize(), indexFd, + options.getSearcherOptions().getMaxResults()); + } + }, TEXT_SEARCHER_NATIVE_LIB); + return new TextSearcher(nativeHandle); + } + + private static native long initJniWithModelFdAndOptions(int modelDescriptor, + long modelDescriptorLength, long modelDescriptorOffset, long baseOptionsHandle, + boolean l2Normalize, boolean quantize, int indexDescriptor, int maxResults); + + private static native long initJniWithByteBuffer(ByteBuffer modelBuffer, long baseOptionsHandle, + boolean l2Normalize, boolean quantize, int indexFileDescriptor, int maxResults); + + /** The native method to search an input text string. 
*/ + private static native List<NearestNeighbor> searchNative(long nativeHandle, String text); + + @Override + protected void deinit(long nativeHandle) { + deinitJni(nativeHandle); + } + + /** + * Native implementation to release memory pointed by the pointer. + * + * @param nativeHandle pointer to memory allocated + */ + private native void deinitJni(long nativeHandle); +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD index 70657c1..172f6c4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD
@@ -16,6 +16,7 @@ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier:image_classifier_src", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core:base_vision_api_src", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector:object_detector_src", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher:image_searcher_src", "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter:image_segmenter_src", ], javacopts = ["-source 7 -target 7"],
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD index a4d48cda..f35dc6c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD
@@ -18,6 +18,13 @@ ), ) +filegroup( + name = "image_classifier_src_experimental", + srcs = glob( + ["**/*.java"], + ), +) + # Default target that uses BuiltInOpResolver, registers all built-in OPs. # IMPORTANT: In order to use hardware acceleration delegates, you must # additionally link to the appropriate delegate plugin target as follows:
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java index 0d35443..e59a2e8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
@@ -17,8 +17,8 @@ import com.google.auto.value.AutoValue; -import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.core.annotations.UsedByReflection; import java.util.ArrayList; import java.util.Collections;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java index 48038f6..5b5be73b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
@@ -21,13 +21,13 @@ import com.google.android.odml.image.MlImage; -import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.image.MlImageAdapter; import org.tensorflow.lite.support.image.TensorImage; import org.tensorflow.lite.task.core.BaseOptions; import org.tensorflow.lite.task.core.TaskJniUtils; import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; +import org.tensorflow.lite.task.core.annotations.UsedByReflection; import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java index 7106fe8..096af52 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
@@ -19,8 +19,8 @@ import com.google.auto.value.AutoValue; -import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.core.annotations.UsedByReflection; import java.util.ArrayList; import java.util.Collections;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java index c0585b8..d1fb421f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
@@ -20,13 +20,13 @@ import com.google.android.odml.image.MlImage; -import org.tensorflow.lite.annotations.UsedByReflection; import org.tensorflow.lite.support.image.MlImageAdapter; import org.tensorflow.lite.support.image.TensorImage; import org.tensorflow.lite.task.core.BaseOptions; import org.tensorflow.lite.task.core.TaskJniUtils; import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; +import org.tensorflow.lite.task.core.annotations.UsedByReflection; import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/AndroidManifest.xml b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/AndroidManifest.xml new file mode 100644 index 0000000..0795a44 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/AndroidManifest.xml
@@ -0,0 +1,5 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.lite.task.vision.searcher"> + <uses-sdk android:minSdkVersion="21" android:targetSdkVersion="29"/> +</manifest>
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/BUILD new file mode 100644 index 0000000..fae7ce30 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/BUILD
@@ -0,0 +1,58 @@ +load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "android_library_with_tflite") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "AndroidManifest.xml", +]) + +filegroup( + name = "image_searcher_src", + srcs = glob(["**/*.java"]), +) + +# Default target that uses BuiltInOpResolver, registers all built-in OPs. +android_library_with_tflite( + name = "image_searcher", + tflite_exports = [ + "//tensorflow_lite_support/java/src/native/task/vision/searcher:image_searcher_native", + ], + exports = [ + ":image_searcher_java", + ], +) + +# Java-only target, need to be used together with a native target similar to +# //third_party/tensorflow_lite_support/java/src/native/task/vision/searcher:image_searcher_native", +# Use this target when you want to provide a MutableOpResolver with customized +# OPs and/or a subset of BuiltInOps to reduce binary size. +android_library( + name = "image_searcher_java", + srcs = glob(["*.java"]) + [ + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/core:base_vision_api_src", + ], + javacopts = ["-source 7 -target 7"], + manifest = "AndroidManifest.xml", + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support_java", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "@com_google_auto_value", + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_android_odml_image", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java_stable", + ], +) + +# AAR target for OSS release. 
+# +# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher:image-searcher +aar_with_jni( + name = "image-searcher", + android_library = ":image_searcher", +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java new file mode 100644 index 0000000..d3d1e6a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/searcher/ImageSearcher.java
@@ -0,0 +1,360 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.vision.searcher; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.graphics.Rect; +import android.os.ParcelFileDescriptor; + +import com.google.android.odml.image.MlImage; +import com.google.auto.value.AutoValue; + +import org.tensorflow.lite.support.image.MlImageAdapter; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.task.core.BaseOptions; +import org.tensorflow.lite.task.core.TaskJniUtils; +import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; +import org.tensorflow.lite.task.processor.NearestNeighbor; +import org.tensorflow.lite.task.processor.SearcherOptions; +import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi; +import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi.InferenceProvider; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.util.List; + +/** + * Performs similarity search on images. + * + * <p>The API expects a TFLite model with optional, but strongly recommended, <a + * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>. 
+ * + * <ul> + * <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) + * <ul> + * <li>image input of size {@code [batch x height x width x channels]}. + * <li>batch inference is not supported ({@code batch} is required to be 1). + * <li>only RGB inputs are supported ({@code channels} is required to be 3). + * <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached + * to the metadata for input normalization. + * </ul> + * <li>Output tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) + * <ul> + * <li>{@code N} components corresponding to the {@code N} dimensions of the returned + * feature vector for this output layer. + * <li>Either 2 or 4 dimensions, i.e. {@code [1 x N]} or {@code [1 x 1 x 1 x N]}. + * </ul> + * </ul> + * + * <p>TODO(b/180502532): add pointer to example model. + * + * <p>TODO(b/222671076): add factory create methods without options, such as `createFromFile`, once + * the single file format (index file packed in the model) is supported. + */ +public final class ImageSearcher extends BaseVisionTaskApi { + private static final String IMAGE_SEARCHER_NATIVE_LIB = "task_vision_jni"; + private static final int OPTIONAL_FD_LENGTH = -1; + private static final int OPTIONAL_FD_OFFSET = -1; + + /** + * Creates an {@link ImageSearcher} instance from {@link ImageSearcherOptions}. 
+ * + * @param modelPath path of the search model with metadata in the assets + * @throws IOException if an I/O error occurs when loading the tflite model or the index file + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ImageSearcher createFromFileAndOptions(Context context, String modelPath, + final ImageSearcherOptions options) throws IOException { + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) { + return createFromModelFdAndOptions( + /*modelDescriptor=*/assetFileDescriptor.getParcelFileDescriptor().getFd(), + /*modelDescriptorLength=*/assetFileDescriptor.getLength(), + /*modelDescriptorOffset=*/assetFileDescriptor.getStartOffset(), options); + } + } + + /** + * Creates an {@link ImageSearcher} instance. + * + * @param modelFile the search model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model or the index file + * @throws IllegalArgumentException if an argument is invalid + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ImageSearcher createFromFileAndOptions( + File modelFile, final ImageSearcherOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return createFromModelFdAndOptions( + /*modelDescriptor=*/descriptor.getFd(), + /*modelDescriptorLength=*/OPTIONAL_FD_LENGTH, + /*modelDescriptorOffset=*/OPTIONAL_FD_OFFSET, options); + } + } + + /** + * Creates an {@link ImageSearcher} instance with a model buffer and {@link + * ImageSearcherOptions}. 
+ * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the search + * model + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + * @throws IOException if an I/O error occurs when loading the index file + * @throws IllegalStateException if there is an internal error + * @throws RuntimeException if there is an otherwise unspecified error + */ + public static ImageSearcher createFromBufferAndOptions( + final ByteBuffer modelBuffer, final ImageSearcherOptions options) throws IOException { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + if (options.getSearcherOptions().getIndexFile() != null) { + try (ParcelFileDescriptor indexDescriptor = + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(), + ParcelFileDescriptor.MODE_READ_ONLY)) { + return createFromBufferAndOptionsImpl( + modelBuffer, options, indexDescriptor.getFd()); + } + } else { + return createFromBufferAndOptionsImpl(modelBuffer, options, /*indexFd=*/0); + } + } + + public static ImageSearcher createFromBufferAndOptionsImpl( + final ByteBuffer modelBuffer, final ImageSearcherOptions options, final int indexFd) { + return new ImageSearcher(TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer(modelBuffer, + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), + options.getSearcherOptions().getL2Normalize(), + options.getSearcherOptions().getQuantize(), indexFd, + options.getSearcherOptions().getMaxResults()); + } + }, IMAGE_SEARCHER_NATIVE_LIB)); + } + + /** + * Constructor to initialize the JNI with a pointer from C++. 
+ * + * @param nativeHandle a pointer referencing memory allocated in C++ + */ + ImageSearcher(long nativeHandle) { + super(nativeHandle); + } + + /** Options for setting up an ImageSearcher. */ + @AutoValue + public abstract static class ImageSearcherOptions { + abstract BaseOptions getBaseOptions(); + + abstract SearcherOptions getSearcherOptions(); + + public static Builder builder() { + return new AutoValue_ImageSearcher_ImageSearcherOptions.Builder() + .setBaseOptions(BaseOptions.builder().build()) + .setSearcherOptions(SearcherOptions.builder().build()); + } + + /** Builder for {@link ImageSearcherOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the general options to configure Task APIs, such as accelerators. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** Sets the options to configure Searcher API. */ + public abstract Builder setSearcherOptions(SearcherOptions searcherOptions); + + public abstract ImageSearcherOptions build(); + } + } + + /** + * Performs embedding extraction on the provided {@link TensorImage}, followed by + * nearest-neighbor search in the index. 
+ * + * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types: + * + * <ul> + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} + * </ul> + * + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image + * @throws IllegalArgumentException if the color space type of image is unsupported + */ + public List<NearestNeighbor> search(TensorImage image) { + return search(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs embedding extraction on the provided {@link TensorImage} with {@link + * ImageProcessingOptions}, followed by nearest-neighbor search in the index. + * + * <p>{@link ImageSearcher} supports the following options: + * + * <ul> + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It + * defaults to the entire image. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. 
+ * </ul> + * + * <p>{@link ImageSearcher} supports the following {@link TensorImage} color space types: + * + * <ul> + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#RGB} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#NV21} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV12} + * <li>{@link org.tensorflow.lite.support.image.ColorSpaceType#YV21} + * </ul> + * + * @param image a UINT8 {@link TensorImage} object that represents an RGB or YUV image + * @throws IllegalArgumentException if the color space type of image is unsupported + */ + public List<NearestNeighbor> search(TensorImage image, ImageProcessingOptions options) { + return run(new InferenceProvider<List<NearestNeighbor>>() { + @Override + public List<NearestNeighbor> run( + long frameBufferHandle, int width, int height, ImageProcessingOptions options) { + return search(frameBufferHandle, width, height, options); + } + }, image, options); + } + + /** + * Performs embedding extraction on the provided {@code MlImage}, followed by nearest-neighbor + * search in the index. + * + * @param image an {@code MlImage} object that represents an image + * @throws IllegalArgumentException if the storage type or format of the image is unsupported + */ + public List<NearestNeighbor> search(MlImage image) { + return search(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs embedding extraction on the provided {@code MlImage} with {@link + * ImageProcessingOptions}, followed by nearest-neighbor search in the index. + * + * <p>{@link ImageSearcher} supports the following options: + * + * <ul> + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions.Builder#setRoi}). It + * defaults to the entire image. + * <li>image rotation (through {@link ImageProcessingOptions.Builder#setOrientation}). It + * defaults to {@link ImageProcessingOptions.Orientation#TOP_LEFT}. 
{@link + * MlImage#getRotation()} is not effective. + * </ul> + * + * @param image a {@code MlImage} object that represents an image + * @param options configures options including ROI and rotation + * @throws IllegalArgumentException if the storage type or format of the image is unsupported + */ + public List<NearestNeighbor> search(MlImage image, ImageProcessingOptions options) { + image.getInternal().acquire(); + TensorImage tensorImage = MlImageAdapter.createTensorImageFrom(image); + List<NearestNeighbor> result = search(tensorImage, options); + image.close(); + return result; + } + + private List<NearestNeighbor> search( + long frameBufferHandle, int width, int height, ImageProcessingOptions options) { + checkNotClosed(); + Rect roi = options.getRoi().isEmpty() ? new Rect(0, 0, width, height) : options.getRoi(); + return searchNative(getNativeHandle(), frameBufferHandle, + new int[] {roi.left, roi.top, roi.width(), roi.height()}); + } + + private static ImageSearcher createFromModelFdAndOptions(final int modelDescriptor, + final long modelDescriptorLength, final long modelDescriptorOffset, + final ImageSearcherOptions options) throws IOException { + if (options.getSearcherOptions().getIndexFile() != null) { + // indexDescriptor must be alive before ImageSearcher is initialized completely in the + // native layer. + try (ParcelFileDescriptor indexDescriptor = + ParcelFileDescriptor.open(options.getSearcherOptions().getIndexFile(), + ParcelFileDescriptor.MODE_READ_ONLY)) { + return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength, + modelDescriptorOffset, options, indexDescriptor.getFd()); + } + } else { + // Index file is not configured. We'll check if the model contains one in the native + // layer. 
+ return createFromModelFdAndOptionsImpl(modelDescriptor, modelDescriptorLength, + modelDescriptorOffset, options, /*indexFd=*/0); + } + } + + private static ImageSearcher createFromModelFdAndOptionsImpl(final int modelDescriptor, + final long modelDescriptorLength, final long modelDescriptorOffset, + final ImageSearcherOptions options, final int indexFd) { + long nativeHandle = TaskJniUtils.createHandleFromLibrary(new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithModelFdAndOptions(modelDescriptor, modelDescriptorLength, + modelDescriptorOffset, + TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()), + options.getSearcherOptions().getL2Normalize(), + options.getSearcherOptions().getQuantize(), indexFd, + options.getSearcherOptions().getMaxResults()); + } + }, IMAGE_SEARCHER_NATIVE_LIB); + return new ImageSearcher(nativeHandle); + } + + private static native long initJniWithModelFdAndOptions(int modelDescriptor, + long modelDescriptorLength, long modelDescriptorOffset, long baseOptionsHandle, + boolean l2Normalize, boolean quantize, int indexDescriptor, int maxResults); + + private static native long initJniWithByteBuffer(ByteBuffer modelBuffer, long baseOptionsHandle, + boolean l2Normalize, boolean quantize, int indexFileDescriptor, int maxResults); + + /** + * The native method to search an image based on the ROI specified. + * + * @param roi the ROI of the input image, an array representing the bounding box as {left, top, + * width, height} + */ + private static native List<NearestNeighbor> searchNative( + long nativeHandle, long frameBufferHandle, int[] roi); + + @Override + protected void deinit(long nativeHandle) { + deinitJni(nativeHandle); + } + + /** + * Native implementation to release memory pointed by the pointer. + * + * @param nativeHandle pointer to memory allocated + */ + private native void deinitJni(long nativeHandle); +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java index 991fede..7a7a5b32 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
@@ -22,7 +22,7 @@ import com.google.auto.value.AutoValue; -import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.task.core.annotations.UsedByReflection; /** Represents a label associated with a color for display purposes. */ @AutoValue
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/BUILD index 120b396..714cd5f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/audio/BUILD
@@ -14,11 +14,18 @@ jni_binary_with_tflite( name = "libtask_audio_jni.so", + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + tflite_deps = [ + ":task_audio_jni_lib", + ], +) + +cc_library_with_tflite( + name = "task_audio_jni_lib", srcs = [ "//tensorflow_lite_support/java/src/native/task/audio/classifier:audio_classifier_jni.cc", "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", ], - linkscript = "//tensorflow_lite_support/java:default_version_script.lds", tflite_deps = [ "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", "//tensorflow_lite_support/cc/task/audio:audio_classifier",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/BUILD index 88498f9..213e67c9 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/BUILD
@@ -19,19 +19,26 @@ "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni.cc", "//tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert:bert_nl_classifier_jni.cc", "//tensorflow_lite_support/java/src/native/task/text/qa:bert_question_answerer_jni.cc", + "//tensorflow_lite_support/java/src/native/task/text/searcher:text_searcher_jni.cc", ], linkscript = "//tensorflow_lite_support/java:default_version_script.lds", tflite_deps = [ "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", "//tensorflow_lite_support/cc/task/text:bert_nl_classifier", "//tensorflow_lite_support/cc/task/text:bert_question_answerer", + "//tensorflow_lite_support/cc/task/text:text_searcher", "//tensorflow_lite_support/cc/utils:jni_utils", "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni_utils", - "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", + # Pack the universal_sentence_encoder_qa_op_resolver (built-in + USE custom ops) + # to the Task Text Java Library by default. + "//tensorflow_lite_support/java/src/native/task/text/searcher:universal_sentence_encoder_qa_op_register", ], deps = [ + "//tensorflow_lite_support/cc/port:statusor", "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", + "//tensorflow_lite_support/cc/task/processor/proto:search_result_cc_proto", "//tensorflow_lite_support/cc/task/text/proto:bert_nl_classifier_options_proto_inc", + "//tensorflow_lite_support/cc/task/text/proto:text_searcher_options_cc_proto", "//tensorflow_lite_support/java/jni", "@org_tensorflow//tensorflow/lite:op_resolver", "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc index c358bee..4413918f 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
@@ -37,8 +37,8 @@ jmethodID category_init = env->GetMethodID(category_class, "<init>", "(Ljava/lang/String;F)V"); - return ConvertVectorToArrayList<Category>( - env, results, + return ConvertVectorToArrayList( + env, results.begin(), results.end(), [env, category_class, category_init](const Category& category) { jstring class_name = env->NewStringUTF(category.class_name.data()); // Convert double to float as Java interface exposes float as scores.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc index 401e6fbd..b77746a 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
@@ -156,8 +156,8 @@ jmethodID qa_answer_ctor = env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V"); - return ConvertVectorToArrayList<QaAnswer>( - env, results, + return ConvertVectorToArrayList( + env, results.begin(), results.end(), [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) { jstring text = env->NewStringUTF(ans.text.data()); jobject qa_answer =
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/BUILD new file mode 100644 index 0000000..7d1a43e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/BUILD
@@ -0,0 +1,58 @@ +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["text_searcher_jni.cc"]) + +cc_library_with_tflite( + name = "text_searcher_native", + tflite_jni_binaries = [ + ":libtask_text_jni.so", + ], +) + +jni_binary_with_tflite( + name = "libtask_text_jni.so", + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + tflite_deps = [ + ":native_without_resolver", + # Pack the universal_sentence_encoder_qa_op_resolver (built-in + USE custom ops) + # to the Task Java Library by default. + # Use `native_without_resolver` if a custom set of ops is preferred. + ":universal_sentence_encoder_qa_op_register", + ], +) + +cc_library_with_tflite( + name = "universal_sentence_encoder_qa_op_register", + srcs = [ + "universal_sentence_encoder_qa_op_register.cc", + ], + tflite_deps = [ + "//tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_op_resolver", + ], +) + +# Shared native logic for TextSearcher. Combine this target and customized +# version of op_resolver to build customized text_searcher_native target. +cc_library_with_tflite( + name = "native_without_resolver", + srcs = [ + "text_searcher_jni.cc", + "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc", + ], + tflite_deps = [ + "//tensorflow_lite_support/cc/task/text:text_searcher", + "//tensorflow_lite_support/cc/utils:jni_utils", + ], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc", + "//tensorflow_lite_support/cc/task/processor/proto:search_result_cc_proto", + "//tensorflow_lite_support/cc/task/text/proto:text_searcher_options_cc_proto", + "//tensorflow_lite_support/java/jni", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc new file mode 100644 index 0000000..c207755 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/text_searcher_jni.cc
@@ -0,0 +1,200 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <jni.h> + +#include <memory> + +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h" +#include "tensorflow_lite_support/cc/task/text/proto/text_searcher_options.pb.h" +#include "tensorflow_lite_support/cc/task/text/text_searcher.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace tflite { +namespace task { +// To be provided by a link-time library +extern std::unique_ptr<OpResolver> CreateOpResolver(); + +} // namespace task +} // namespace tflite + +namespace { + +using ::tflite::support::StatusOr; +using ::tflite::support::utils::ConvertVectorToArrayList; +using ::tflite::support::utils::CreateByteArray; +using ::tflite::support::utils::GetExceptionClassNameForStatusCode; +using ::tflite::support::utils::JStringToString; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::ThrowException; +using ::tflite::task::core::BaseOptions; +using ::tflite::task::processor::NearestNeighbor; +using ::tflite::task::processor::SearchResult; +using ::tflite::task::text::TextSearcher; +using ::tflite::task::text::TextSearcherOptions; + +// Creates an TextSearcherOptions proto 
based on the Java class. +TextSearcherOptions ConvertToProtoOptions(jlong base_options_handle, + bool l2_normalize, + bool quantize, + int index_descriptor, + int max_results) { + TextSearcherOptions proto_options; + + if (base_options_handle != kInvalidPointer) { + // proto_options will free the previous base_options and set the new one. + proto_options.set_allocated_base_options( + reinterpret_cast<BaseOptions*>(base_options_handle)); + } + + auto embedding_options = proto_options.mutable_embedding_options(); + embedding_options->set_l2_normalize(l2_normalize); + embedding_options->set_quantize(quantize); + + auto search_options = proto_options.mutable_search_options(); + if (index_descriptor > 0) { + search_options->mutable_index_file() + ->mutable_file_descriptor_meta() + ->set_fd(index_descriptor); + } + search_options->set_max_results(max_results); + + return proto_options; +} + +jlong CreateTextSearcherFromOptions(JNIEnv* env, + const TextSearcherOptions& options) { + StatusOr<std::unique_ptr<TextSearcher>> text_searcher_or = + TextSearcher::CreateFromOptions(options, + tflite::task::CreateOpResolver()); + if (text_searcher_or.ok()) { + return reinterpret_cast<jlong>(text_searcher_or->release()); + } else { + ThrowException( + env, + GetExceptionClassNameForStatusCode(text_searcher_or.status().code()), + "Error occurred when initializing TextSearcher: %s", + text_searcher_or.status().message().data()); + return kInvalidPointer; + } +} + +jobject ConvertToSearchResults(JNIEnv* env, const SearchResult& results) { + // jclass and factory create of NearestNeighbor. 
+ jclass nearest_neighbor_class = + env->FindClass("org/tensorflow/lite/task/processor/NearestNeighbor"); + jmethodID nearest_neighbor_create = + env->GetStaticMethodID(nearest_neighbor_class, "create", + "([BF)Lorg/tensorflow/lite/" + "task/processor/NearestNeighbor;"); + + return ConvertVectorToArrayList( + env, results.nearest_neighbors().begin(), + results.nearest_neighbors().end(), + [env, nearest_neighbor_class, + nearest_neighbor_create](const NearestNeighbor& neightbor) { + jbyteArray jmetadata = CreateByteArray( + env, reinterpret_cast<const jbyte*>(neightbor.metadata().data()), + neightbor.metadata().size()); + jobject jnearest_neighbor = env->CallStaticObjectMethod( + nearest_neighbor_class, nearest_neighbor_create, jmetadata, + neightbor.distance()); + env->DeleteLocalRef(jmetadata); + return jnearest_neighbor; + }); +} + +} // namespace + +extern "C" JNIEXPORT void JNICALL +Java_org_tensorflow_lite_task_text_searcher_TextSearcher_deinitJni( + JNIEnv* env, + jobject thiz, + jlong native_handle) { + delete reinterpret_cast<TextSearcher*>(native_handle); +} + +// Creates an TextSearcher instance from the model file descriptor. +// file_descriptor_length and file_descriptor_offset are optional. Non-positive +// values will be ignored. 
+extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithModelFdAndOptions( + JNIEnv* env, + jclass thiz, + jint model_descriptor, + jlong model_descriptor_length, + jlong model_descriptor_offset, + jlong base_options_handle, + bool l2_normalize, + bool quantize, + jint index_descriptor, + int max_results) { + TextSearcherOptions proto_options = + ConvertToProtoOptions(base_options_handle, l2_normalize, quantize, + index_descriptor, max_results); + auto file_descriptor_meta = proto_options.mutable_base_options() + ->mutable_model_file() + ->mutable_file_descriptor_meta(); + file_descriptor_meta->set_fd(model_descriptor); + if (model_descriptor_length > 0) { + file_descriptor_meta->set_length(model_descriptor_length); + } + if (model_descriptor_offset > 0) { + file_descriptor_meta->set_offset(model_descriptor_offset); + } + + return CreateTextSearcherFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_text_searcher_TextSearcher_initJniWithByteBuffer( + JNIEnv* env, + jclass thiz, + jobject model_buffer, + jlong base_options_handle, + bool l2_normalize, + bool quantize, + jlong index_descriptor, + int max_results) { + TextSearcherOptions proto_options = + ConvertToProtoOptions(base_options_handle, l2_normalize, quantize, + index_descriptor, max_results); + proto_options.mutable_base_options()->mutable_model_file()->set_file_content( + static_cast<char*>(env->GetDirectBufferAddress(model_buffer)), + static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer))); + + return CreateTextSearcherFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT jobject JNICALL +Java_org_tensorflow_lite_task_text_searcher_TextSearcher_searchNative( + JNIEnv* env, + jclass thiz, + jlong native_handle, + jstring text) { + auto* searcher = reinterpret_cast<TextSearcher*>(native_handle); + auto results_or = searcher->Search(JStringToString(env, text)); + + if (results_or.ok()) { + 
return ConvertToSearchResults(env, results_or.value()); + } else { + ThrowException( + env, GetExceptionClassNameForStatusCode(results_or.status().code()), + "Error occurred when searching the input text: %s", + results_or.status().message().data()); + return nullptr; + } +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/universal_sentence_encoder_qa_op_register.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/universal_sentence_encoder_qa_op_register.cc new file mode 100644 index 0000000..440af02 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/text/searcher/universal_sentence_encoder_qa_op_register.cc
@@ -0,0 +1,28 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
+
+namespace tflite {
+namespace task {
+// Provides a custom OpResolver for TextSearcher Java API by delegating to
+// text::CreateQACustomOpResolver() (Universal Sentence Encoder QA ops, per the
+// included header's name). Ownership of the resolver passes to the caller.
+std::unique_ptr<OpResolver> CreateOpResolver() {
+  return tflite::task::text::CreateQACustomOpResolver();
+}
+
+}  // namespace task
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/BUILD index 6f784145..9d1e18c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/BUILD
@@ -41,6 +41,7 @@
         "//tensorflow_lite_support/java/src/native/task/vision/classifier:image_classifier_jni.cc",
         "//tensorflow_lite_support/java/src/native/task/vision/core:base_vision_task_api_jni.cc",
         "//tensorflow_lite_support/java/src/native/task/vision/detector:object_detector_jni.cc",
+        "//tensorflow_lite_support/java/src/native/task/vision/searcher:image_searcher_jni.cc",
         "//tensorflow_lite_support/java/src/native/task/vision/segmenter:image_segmenter_jni.cc",
     ],
     linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
@@ -51,15 +52,19 @@
         "//tensorflow_lite_support/cc/task/vision:object_detector",
         "//tensorflow_lite_support/cc/utils:jni_utils",
         "//tensorflow_lite_support/java/src/native/task/vision:jni_utils",
+        # NOTE(review): not in lexicographic order with the entries above.
+        "//tensorflow_lite_support/cc/task/vision:image_searcher",
     ],
     deps = [
         "//tensorflow_lite_support/cc/port:statusor",
         "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc",
+        "//tensorflow_lite_support/cc/task/processor/proto:search_result_cc_proto",
         "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
         "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc",
         "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc",
         "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc",
         "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc",
+        "//tensorflow_lite_support/cc/task/vision/proto:image_searcher_options_cc_proto",
         "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc",
         "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc",
         "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/BUILD new file mode 100644 index 0000000..c09ee42 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/BUILD
@@ -0,0 +1,52 @@
+load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "jni_binary_with_tflite")
+
+package(
+    default_visibility = ["//tensorflow_lite_support:users"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+exports_files(["image_searcher_jni.cc"])
+
+# ImageSearcher native library; its JNI binary links the builtin op resolver.
+cc_library_with_tflite(
+    name = "image_searcher_native",
+    tflite_jni_binaries = [
+        ":libtask_vision_jni.so",
+    ],
+)
+
+# JNI shared library: ImageSearcher native logic plus the builtin op resolver.
+jni_binary_with_tflite(
+    name = "libtask_vision_jni.so",
+    linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
+    tflite_deps = [
+        ":native_without_resolver",
+        "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver",
+    ],
+)
+
+# Shared native logic for ImageSearcher, without an op resolver. Combine this
+# target with a custom op resolver library (one defining CreateOpResolver()) to
+# build a customized image_searcher_native target.
+cc_library_with_tflite(
+    name = "native_without_resolver",
+    srcs = [
+        "image_searcher_jni.cc",
+        "//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc",
+        "//tensorflow_lite_support/java/src/native/task/vision/core:base_vision_task_api_jni.cc",
+    ],
+    tflite_deps = [
+        "//tensorflow_lite_support/cc/task/vision:image_searcher",
+        "//tensorflow_lite_support/cc/utils:jni_utils",
+        "//tensorflow_lite_support/java/src/native/task/vision:jni_utils",
+    ],
+    deps = [
+        "//tensorflow_lite_support/cc/port:statusor",
+        "//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc",
+        "//tensorflow_lite_support/cc/task/processor/proto:search_result_cc_proto",
+        "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+        "//tensorflow_lite_support/cc/task/vision/proto:image_searcher_options_cc_proto",
+        "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils",
+        "//tensorflow_lite_support/java/jni",
+    ],
+)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc new file mode 100644 index 0000000..84cad5db --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/java/src/native/task/vision/searcher/image_searcher_jni.cc
@@ -0,0 +1,268 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <jni.h>
+
+#include <memory>
+
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/image_searcher.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/image_searcher_options.pb.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h"
+
+namespace tflite {
+namespace task {
+// To be provided by a link-time library
+extern std::unique_ptr<OpResolver> CreateOpResolver();
+
+}  // namespace task
+}  // namespace tflite
+
+namespace {
+
+using ::tflite::support::StatusOr;
+using ::tflite::support::utils::ConvertVectorToArrayList;
+using ::tflite::support::utils::CreateByteArray;
+using ::tflite::support::utils::GetExceptionClassNameForStatusCode;
+using ::tflite::support::utils::kInvalidPointer;
+using ::tflite::support::utils::ThrowException;
+using ::tflite::task::core::BaseOptions;
+using ::tflite::task::processor::NearestNeighbor;
+using ::tflite::task::processor::SearchResult;
+using ::tflite::task::vision::BoundingBox;
+using ::tflite::task::vision::FrameBuffer;
+using ::tflite::task::vision::ImageSearcher;
+using ::tflite::task::vision::ImageSearcherOptions;
+
+// JNI class name of the exception thrown for malformed arguments.
+constexpr char kIllegalArgumentExceptionClass[] =
+    "java/lang/IllegalArgumentException";
+
+// Creates an ImageSearcherOptions proto based on the Java class.
+//
+// If `base_options_handle` is valid, the proto takes ownership of the native
+// BaseOptions it points to. A non-positive `index_descriptor` leaves the index
+// file unset.
+ImageSearcherOptions ConvertToProtoOptions(jlong base_options_handle,
+                                           bool l2_normalize,
+                                           bool quantize,
+                                           int index_descriptor,
+                                           int max_results) {
+  ImageSearcherOptions proto_options;
+
+  if (base_options_handle != kInvalidPointer) {
+    // proto_options will free the previous base_options and set the new one.
+    proto_options.set_allocated_base_options(
+        reinterpret_cast<BaseOptions*>(base_options_handle));
+  }
+
+  auto embedding_options = proto_options.mutable_embedding_options();
+  embedding_options->set_l2_normalize(l2_normalize);
+  embedding_options->set_quantize(quantize);
+
+  auto search_options = proto_options.mutable_search_options();
+  if (index_descriptor > 0) {
+    search_options->mutable_index_file()
+        ->mutable_file_descriptor_meta()
+        ->set_fd(index_descriptor);
+  }
+  search_options->set_max_results(max_results);
+
+  return proto_options;
+}
+
+// Creates an ImageSearcher from `options` and returns it as a native handle
+// owned by the caller; on failure, throws a Java exception and returns
+// kInvalidPointer.
+jlong CreateImageSearcherFromOptions(JNIEnv* env,
+                                     const ImageSearcherOptions& options) {
+  StatusOr<std::unique_ptr<ImageSearcher>> image_searcher_or =
+      ImageSearcher::CreateFromOptions(options,
+                                       tflite::task::CreateOpResolver());
+  if (image_searcher_or.ok()) {
+    return reinterpret_cast<jlong>(image_searcher_or->release());
+  } else {
+    ThrowException(
+        env,
+        GetExceptionClassNameForStatusCode(image_searcher_or.status().code()),
+        "Error occurred when initializing ImageSearcher: %s",
+        image_searcher_or.status().message().data());
+    return kInvalidPointer;
+  }
+}
+
+// Converts `results` into a Java list (built by ConvertVectorToArrayList) of
+// org.tensorflow.lite.task.processor.NearestNeighbor objects. Returns nullptr
+// with a pending Java exception if NearestNeighbor or its `create` factory
+// cannot be resolved.
+jobject ConvertToSearchResults(JNIEnv* env, const SearchResult& results) {
+  // jclass and factory create of NearestNeighbor.
+  jclass nearest_neighbor_class =
+      env->FindClass("org/tensorflow/lite/task/processor/NearestNeighbor");
+  if (nearest_neighbor_class == nullptr) {
+    return nullptr;
+  }
+  jmethodID nearest_neighbor_create =
+      env->GetStaticMethodID(nearest_neighbor_class, "create",
+                             "([BF)Lorg/tensorflow/lite/"
+                             "task/processor/NearestNeighbor;");
+  if (nearest_neighbor_create == nullptr) {
+    env->DeleteLocalRef(nearest_neighbor_class);
+    return nullptr;
+  }
+
+  jobject jnearest_neighbors = ConvertVectorToArrayList(
+      env, results.nearest_neighbors().begin(),
+      results.nearest_neighbors().end(),
+      [env, nearest_neighbor_class,
+       nearest_neighbor_create](const NearestNeighbor& neighbor) {
+        jbyteArray jmetadata = CreateByteArray(
+            env, reinterpret_cast<const jbyte*>(neighbor.metadata().data()),
+            neighbor.metadata().size());
+        jobject jnearest_neighbor = env->CallStaticObjectMethod(
+            nearest_neighbor_class, nearest_neighbor_create, jmetadata,
+            neighbor.distance());
+        env->DeleteLocalRef(jmetadata);
+        return jnearest_neighbor;
+      });
+  // The class reference is only needed while the list is being built.
+  env->DeleteLocalRef(nearest_neighbor_class);
+  return jnearest_neighbors;
+}
+
+}  // namespace
+
+// Releases the ImageSearcher referenced by `native_handle`.
+extern "C" JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_deinitJni(
+    JNIEnv* env,
+    jobject thiz,
+    jlong native_handle) {
+  delete reinterpret_cast<ImageSearcher*>(native_handle);
+}
+
+// Creates an ImageSearcher instance from the model file descriptor.
+// file_descriptor_length and file_descriptor_offset are optional. Non-positive
+// values will be ignored.
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithModelFdAndOptions(
+    JNIEnv* env,
+    jclass thiz,
+    jint model_descriptor,
+    jlong model_descriptor_length,
+    jlong model_descriptor_offset,
+    jlong base_options_handle,
+    bool l2_normalize,
+    bool quantize,
+    jint index_descriptor,
+    int max_results) {
+  ImageSearcherOptions proto_options =
+      ConvertToProtoOptions(base_options_handle, l2_normalize, quantize,
+                            index_descriptor, max_results);
+  auto file_descriptor_meta = proto_options.mutable_base_options()
+                                  ->mutable_model_file()
+                                  ->mutable_file_descriptor_meta();
+  file_descriptor_meta->set_fd(model_descriptor);
+  if (model_descriptor_length > 0) {
+    file_descriptor_meta->set_length(model_descriptor_length);
+  }
+  if (model_descriptor_offset > 0) {
+    file_descriptor_meta->set_offset(model_descriptor_offset);
+  }
+
+  return CreateImageSearcherFromOptions(env, proto_options);
+}
+
+// Creates an ImageSearcher instance from a model held in a direct ByteBuffer.
+// Throws IllegalArgumentException and returns kInvalidPointer if
+// `model_buffer` is not a direct buffer.
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_initJniWithByteBuffer(
+    JNIEnv* env,
+    jclass thiz,
+    jobject model_buffer,
+    jlong base_options_handle,
+    bool l2_normalize,
+    bool quantize,
+    jlong index_descriptor,
+    int max_results) {
+  // Build the options first: they take ownership of the native BaseOptions (if
+  // any), which are then also released on the early-return path below.
+  // File descriptors are ints, hence the explicit narrowing.
+  ImageSearcherOptions proto_options = ConvertToProtoOptions(
+      base_options_handle, l2_normalize, quantize,
+      static_cast<int>(index_descriptor), max_results);
+  // GetDirectBufferAddress() returns NULL and GetDirectBufferCapacity()
+  // returns -1 for buffers that are not direct.
+  void* model_data = env->GetDirectBufferAddress(model_buffer);
+  const jlong model_size = env->GetDirectBufferCapacity(model_buffer);
+  if (model_data == nullptr || model_size < 0) {
+    ThrowException(env, kIllegalArgumentExceptionClass,
+                   "The model buffer must be a direct ByteBuffer.");
+    return kInvalidPointer;
+  }
+  proto_options.mutable_base_options()->mutable_model_file()->set_file_content(
+      static_cast<char*>(model_data), static_cast<size_t>(model_size));
+
+  return CreateImageSearcherFromOptions(env, proto_options);
+}
+
+// Searches the frame buffer behind `frame_buffer_handle`, restricted to the
+// region of interest `jroi` = {origin_x, origin_y, width, height}. Returns
+// nullptr with a pending Java exception on failure.
+extern "C" JNIEXPORT jobject JNICALL
+Java_org_tensorflow_lite_task_vision_searcher_ImageSearcher_searchNative(
+    JNIEnv* env,
+    jclass thiz,
+    jlong native_handle,
+    jlong frame_buffer_handle,
+    jintArray jroi) {
+  auto* searcher = reinterpret_cast<ImageSearcher*>(native_handle);
+  // frame_buffer will be deleted after inference is done in
+  // base_vision_task_api_jni.cc.
+  auto* frame_buffer = reinterpret_cast<FrameBuffer*>(frame_buffer_handle);
+
+  if (jroi == nullptr || env->GetArrayLength(jroi) < 4) {
+    ThrowException(env, kIllegalArgumentExceptionClass,
+                   "The ROI must contain 4 elements.");
+    return nullptr;
+  }
+  jint* roi_array = env->GetIntArrayElements(jroi, /*isCopy=*/nullptr);
+  if (roi_array == nullptr) {
+    return nullptr;  // The JVM could not provide the elements (e.g. OOM).
+  }
+  BoundingBox roi;
+  roi.set_origin_x(roi_array[0]);
+  roi.set_origin_y(roi_array[1]);
+  roi.set_width(roi_array[2]);
+  roi.set_height(roi_array[3]);
+  // The elements were only read, so there is nothing to copy back.
+  env->ReleaseIntArrayElements(jroi, roi_array, JNI_ABORT);
+
+  auto results_or = searcher->Search(*frame_buffer, roi);
+  if (results_or.ok()) {
+    return ConvertToSearchResults(env, results_or.value());
+  } else {
+    ThrowException(
+        env, GetExceptionClassNameForStatusCode(results_or.status().code()),
+        "Error occurred when searching the image: %s",
+        results_or.status().message().data());
+    return nullptr;
+  }
+}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/BUILD index 5e5c25b..6d77f69 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/BUILD
@@ -21,14 +21,16 @@
         "//tensorflow_lite_support/cc/port:status_macros",
         "//tensorflow_lite_support/cc/port:statusor",
         "//tensorflow_lite_support/metadata:metadata_schema_cc",
+        "//tensorflow_lite_support/metadata/cc/utils:zip_readonly_mem_file",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@flatbuffers",
-        "@org_libzip//:zip",
         "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+        # minizip replaces @org_libzip for reading zip archives.
+        "@zlib//:zlib_minizip",
     ],
 )
 
@@ -59,7 +61,7 @@
         "//tensorflow_lite_support/cc/port:status_macros",
         "//tensorflow_lite_support/cc/port:statusor",
         "//tensorflow_lite_support/metadata:metadata_schema_cc",
-        "//tensorflow_lite_support/metadata/cc/utils:zip_mem_file",
+        "//tensorflow_lite_support/metadata/cc/utils:zip_writable_mem_file",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/status",
         "@flatbuffers//:runtime_cc",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc index 3aae0aa..2a723387 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
@@ -15,16 +15,19 @@
 
 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
 
-#include <functional>
+#include <string>
 
-#include "absl/memory/memory.h"  // from @com_google_absl
-#include "absl/status/status.h"  // from @com_google_absl
-#include "absl/strings/str_format.h"  // from @com_google_absl
+#include "absl/memory/memory.h"        // from @com_google_absl
+#include "absl/status/status.h"        // from @com_google_absl
+#include "absl/strings/str_format.h"   // from @com_google_absl
+#include "absl/strings/string_view.h"  // from @com_google_absl
+#include "contrib/minizip/ioapi.h"
+#include "contrib/minizip/unzip.h"
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
-#include "lib/zip.h"  // from @org_libzip
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow_lite_support/cc/common.h"
 #include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h"
 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
 
 namespace tflite {
@@ -40,27 +43,6 @@
 using ::tflite::support::CreateStatusWithPayload;
 using ::tflite::support::TfLiteSupportStatus;
 
-// Helper class that takes a callback function, and invokes it in its
-// destructor.
-class SimpleCleanUp {
- public:
-  explicit SimpleCleanUp(std::function<void()> callback)
-      : callback_(std::move(callback)) {}
-
-  ~SimpleCleanUp() {
-    if (callback_ != nullptr)
-      callback_();
-  }
-
-  // Use `std::move(simple_cleanup).Cancel()` to prevent the callback from ever
-  // executing at all. Once a SimpleCleanUp object has been `std::move(...)`-ed,
-  // it may not be read from again.
-  void Cancel() && { callback_ = nullptr; }
-
- private:
-  std::function<void()> callback_;
-};
-
 // Util to get item from src_vector specified by index.
 template <typename T>
 const T* GetItemFromVector(
@@ -71,6 +53,69 @@
   }
   return src_vector->Get(index);
 }
+
+// Wrapper function around calls to unzip to avoid repeating conversion logic
+// from error code to Status.
+absl::Status UnzipErrorToStatus(int error) {
+  if (error != UNZ_OK) {
+    return CreateStatusWithPayload(
+        StatusCode::kUnknown, "Unable to read associated file in zip archive.",
+        TfLiteSupportStatus::kMetadataAssociatedFileZipError);
+  }
+  return absl::OkStatus();
+}
+
+// Stores a file name, position in zip buffer and size.
+struct ZipFileInfo {
+  std::string name;
+  ZPOS64_T position;
+  ZPOS64_T size;
+};
+
+// Returns the ZipFileInfo corresponding to the current file in the provided
+// unzFile object. On error the current file may be left open; callers are
+// expected to unzClose() the archive afterwards.
+absl::StatusOr<ZipFileInfo> GetCurrentZipFileInfo(const unzFile& zf) {
+  // Open file in raw mode, as data is expected to be uncompressed.
+  int method;
+  RETURN_IF_ERROR(UnzipErrorToStatus(
+      unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1)));
+  if (method != Z_NO_COMPRESSION) {
+    return CreateStatusWithPayload(
+        StatusCode::kUnknown, "Expected uncompressed zip archive.",
+        TfLiteSupportStatus::kMetadataAssociatedFileZipError);
+  }
+
+  // Get file info a first time to get filename size.
+  unz_file_info64 file_info;
+  RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
+      zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0,
+      /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
+      /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
+
+  // Second call to get file name, read directly into a std::string so that
+  // no manually managed buffer can leak on the error path.
+  std::string file_name(file_info.size_filename, '\0');
+  RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
+      zf, &file_info, &file_name[0], file_name.size(),
+      /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
+      /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
+
+  // Get position in file.
+  auto position = unzGetCurrentFileZStreamPos64(zf);
+  if (position == 0) {
+    return CreateStatusWithPayload(
+        StatusCode::kUnknown, "Unable to read file in zip archive.",
+        TfLiteSupportStatus::kMetadataAssociatedFileZipError);
+  }
+  ZipFileInfo result = {.name = std::move(file_name),
+                        .position = position,
+                        .size = file_info.uncompressed_size};
+
+  // Close file and return.
+  RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf)));
+  return result;
+}
 }  // namespace
 
 /* static */
@@ -193,71 +238,61 @@
 
 absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
     const char* buffer_data, size_t buffer_size) {
-  // Setup libzip error reporting.
-  zip_error_t error;
-  zip_error_init(&error);
-  auto zip_error_cleanup = SimpleCleanUp([&error] { zip_error_fini(&error); });
-
-  // Initialize zip source.
-  zip_source_t* src =
-      zip_source_buffer_create(buffer_data, buffer_size, /*freep=*/0, &error);
-  if (src == nullptr) {
-    return CreateStatusWithPayload(
-        StatusCode::kUnknown,
-        absl::StrFormat("Can't create zip source from model buffer: %s",
-                        zip_error_strerror(&error)),
-        TfLiteSupportStatus::kMetadataAssociatedFileZipError);
-  }
-  auto zip_source_cleanup = SimpleCleanUp([src] { zip_source_free(src); });
-
-  // Try opening zip source.
-  zip* zip_archive = zip_open_from_source(src, /*flags=*/0, &error);
-  if (zip_archive == nullptr) {
+  // Create in-memory read-only zip file.
+  ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size);
+  // Open zip.
+  unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def());
+  if (zf == nullptr) {
     // It's OK if it fails: this means there are no associated files with this
     // model.
     return absl::OkStatus();
   }
-  auto zip_archive_cleanup =
-      SimpleCleanUp([zip_archive] { zip_close(zip_archive); });
-  // As per the documentation [1] for zip_source_free, it should not be called
-  // after a successful call to zip_open_from_source.
-  //
-  // [1]: https://libzip.org/documentation/zip_source_free.html
-  std::move(zip_source_cleanup).Cancel();
+  // Read the entries inside a lambda so that the archive is closed below on
+  // every path, including the early error returns.
+  absl::Status status = [&]() -> absl::Status {
+    // Get number of files.
+    unz_global_info global_info;
+    if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) {
+      return CreateStatusWithPayload(
+          StatusCode::kUnknown, "Unable to get zip archive info.",
+          TfLiteSupportStatus::kMetadataAssociatedFileZipError);
+    }
 
-  const int num_files = zip_get_num_entries(zip_archive, /*flags=*/0);
-  for (int index = 0; index < num_files; ++index) {
-    // Get file stats.
-    struct zip_stat zip_file_stat;
-    zip_stat_init(&zip_file_stat);
-    zip_stat_index(zip_archive, index, /*flags=*/0, &zip_file_stat);
-    absl::string_view filename = zip_file_stat.name;
-    const auto unzip_filesize = zip_file_stat.size;
-
-    // Open file.
-    zip_file* zip_file = zip_fopen_index(zip_archive, index, /*flags=*/0);
-    if (zip_file == nullptr) {
+    // Browse through files in archive.
+    if (global_info.number_entry == 0) {
+      return absl::OkStatus();
+    }
+    int error = unzGoToFirstFile(zf);
+    while (error == UNZ_OK) {
+      ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf));
+      // Entries are stored uncompressed and referenced in place, so reject
+      // any whose declared position/size would point outside of the model
+      // buffer.
+      if (zip_file_info.position > buffer_size ||
+          zip_file_info.size > buffer_size - zip_file_info.position) {
+        return CreateStatusWithPayload(
+            StatusCode::kUnknown,
+            "Associated file in zip archive exceeds the model buffer.",
+            TfLiteSupportStatus::kMetadataAssociatedFileZipError);
+      }
+      // Store result in map.
+      associated_files_[zip_file_info.name] = absl::string_view(
+          buffer_data + zip_file_info.position, zip_file_info.size);
+      error = unzGoToNextFile(zf);
+    }
+    if (error != UNZ_END_OF_LIST_OF_FILE) {
       return CreateStatusWithPayload(
           StatusCode::kUnknown,
-          absl::StrFormat("Unable to open associated file with name: %s",
-                          zip_file_stat.name),
+          "Unable to read associated file in zip archive.",
           TfLiteSupportStatus::kMetadataAssociatedFileZipError);
     }
-    auto zip_file_cleanup = SimpleCleanUp([zip_file] { zip_fclose(zip_file); });
-
-    // Unzip file.
-    char* unzip_buffer = new char[unzip_filesize];
-    auto unzip_buffer_cleanup =
-        SimpleCleanUp([unzip_buffer] { delete[] unzip_buffer; });
-    if (zip_fread(zip_file, unzip_buffer, unzip_filesize) != unzip_filesize) {
-      return CreateStatusWithPayload(
-          StatusCode::kUnknown,
-          absl::StrFormat("Unzipping failed for file: %s.", filename),
-          TfLiteSupportStatus::kMetadataAssociatedFileZipError);
-    }
-
-    // Copy file contents in map.
-    associated_files_[filename] = std::string(unzip_buffer, unzip_filesize);
+    return absl::OkStatus();
+  }();
+  // Close zip. A read error takes precedence over a close error.
+  if (unzClose(zf) != UNZ_OK && status.ok()) {
+    return CreateStatusWithPayload(
+        StatusCode::kUnknown, "Unable to close zip archive.",
+        TfLiteSupportStatus::kMetadataAssociatedFileZipError);
   }
-  return absl::OkStatus();
+  return status;
 }
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h index dc9a992a..007919d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.h
@@ -146,9 +146,9 @@
   // Pointer to the extracted ModelMetadata, if any.
   const tflite::ModelMetadata* model_metadata_{nullptr};
   // The files associated with the ModelMetadata, as a map with the filename
-  // (corresponding to a basename, e.g. "labels.txt") as key and the file
-  // contents as value.
-  absl::flat_hash_map<std::string, std::string> associated_files_;
+  // (corresponding to a basename, e.g. "labels.txt") as key and a view of
+  // the file contents (borrowed from the model buffer, not owned) as value.
+  absl::flat_hash_map<std::string, absl::string_view> associated_files_;
 };
 
 }  // namespace metadata
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc new file mode 100644 index 0000000..299ade3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc
@@ -0,0 +1,168 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/metadata/cc/metadata_populator.h"
+
+#include <algorithm>
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "contrib/minizip/ioapi.h"
+#include "contrib/minizip/zip.h"
+#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace metadata {
+
+namespace {
+constexpr char kMetadataBufferName[] = "TFLITE_METADATA";
+
+using ::absl::StatusCode;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::TfLiteSupportStatus;
+
+}  // namespace
+
+ModelMetadataPopulator::ModelMetadataPopulator(const tflite::Model& model) {
+  model.UnPackTo(&model_t_);
+}
+
+/* static */
+tflite::support::StatusOr<std::unique_ptr<ModelMetadataPopulator>>
+ModelMetadataPopulator::CreateFromModelBuffer(const char* buffer_data,
+                                              size_t buffer_size) {
+  // Rely on the simplest, base flatbuffers verifier. Here is not the place to
+  // e.g. use an OpResolver: we just want to make sure the buffer is valid to
+  // access the metadata.
+  flatbuffers::Verifier verifier = flatbuffers::Verifier(
+      reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
+  if (!tflite::VerifyModelBuffer(verifier)) {
+    return CreateStatusWithPayload(
+        StatusCode::kInvalidArgument,
+        "The model is not a valid FlatBuffer buffer.",
+        TfLiteSupportStatus::kInvalidFlatBufferError);
+  }
+  // Use absl::WrapUnique() to call private constructor:
+  // https://abseil.io/tips/126.
+  return absl::WrapUnique(
+      new ModelMetadataPopulator(*tflite::GetModel(buffer_data)));
+}
+
+void ModelMetadataPopulator::LoadMetadata(const char* metadata_buffer_data,
+                                          size_t metadata_buffer_size) {
+  // Pack the model metadata in a buffer.
+  auto model_metadata_buffer = std::make_unique<tflite::BufferT>();
+  model_metadata_buffer->data = {metadata_buffer_data,
+                                 metadata_buffer_data + metadata_buffer_size};
+  // Check if the model already has metadata. If so, just override the buffer
+  // and exit.
+  for (const auto& metadata_t : model_t_.metadata) {
+    if (metadata_t->name == kMetadataBufferName) {
+      model_t_.buffers[metadata_t->buffer] = std::move(model_metadata_buffer);
+      return;
+    }
+  }
+  // Model doesn't already have metadata: add metadata buffer and pointer to the
+  // buffer in the model metadata section.
+  model_t_.buffers.push_back(std::move(model_metadata_buffer));
+  auto metadata_t = std::make_unique<tflite::MetadataT>();
+  metadata_t->name = kMetadataBufferName;
+  metadata_t->buffer = model_t_.buffers.size() - 1;
+  model_t_.metadata.push_back(std::move(metadata_t));
+}
+
+void ModelMetadataPopulator::LoadAssociatedFiles(
+    const absl::flat_hash_map<std::string, std::string>& associated_files) {
+  associated_files_ = associated_files;
+}
+
+tflite::support::StatusOr<std::string>
+ModelMetadataPopulator::AppendAssociatedFiles(const char* model_buffer_data,
+                                              size_t model_buffer_size) {
+  // Create in-memory writable zip file.
+  ZipWritableMemFile mem_file =
+      ZipWritableMemFile(model_buffer_data, model_buffer_size);
+  // Open zip.
+  zipFile zf =
+      zipOpen2_64(/*pathname=*/nullptr, APPEND_STATUS_CREATEAFTER,
+                  /*globalcomment=*/nullptr, &mem_file.GetFileFunc64Def());
+  if (zf == nullptr) {
+    return CreateStatusWithPayload(
+        StatusCode::kUnknown, "Unable to open zip archive",
+        TfLiteSupportStatus::kMetadataAssociatedFileZipError);
+  }
+  // Write associated files in name order: flat_hash_map iteration order is
+  // unspecified, and a fixed order keeps the output bytes deterministic.
+  std::vector<const std::string*> names;
+  names.reserve(associated_files_.size());
+  for (const auto& entry : associated_files_) {
+    names.push_back(&entry.first);
+  }
+  std::sort(names.begin(), names.end(),
+            [](const std::string* a, const std::string* b) { return *a < *b; });
+  for (const std::string* name : names) {
+    const auto& contents = associated_files_.find(*name)->second;
+    if ((zipOpenNewFileInZip64(zf, name->c_str(),
+                               /*zipfi=*/nullptr,
+                               /*extrafield_local=*/nullptr,
+                               /*size_extrafield_local=*/0,
+                               /*extrafield_global=*/nullptr,
+                               /*size_extrafield_global=*/0,
+                               /*comment=*/nullptr,
+                               /*method=*/0,
+                               /*level=*/Z_DEFAULT_COMPRESSION,
+                               /*zip64=*/0) != ZIP_OK) ||
+        (zipWriteInFileInZip(zf, contents.data(), contents.length()) !=
+         ZIP_OK) ||
+        (zipCloseFileInZip(zf) != ZIP_OK)) {
+      // Release the archive handle before bailing out; the partially written
+      // archive is discarded along with `mem_file`.
+      zipClose(zf, /*global_comment=*/nullptr);
+      return CreateStatusWithPayload(
+          StatusCode::kUnknown, "Unable to write file to zip archive",
+          TfLiteSupportStatus::kMetadataAssociatedFileZipError);
+    }
+  }
+  // Close zip.
+  if (zipClose(zf, /*global_comment=*/nullptr) != ZIP_OK) {
+    return CreateStatusWithPayload(
+        StatusCode::kUnknown, "Unable to close zip archive",
+        TfLiteSupportStatus::kMetadataAssociatedFileZipError);
+  }
+  // Return as a string.
+  return std::string(mem_file.GetFileContent());
+}
+
+tflite::support::StatusOr<std::string> ModelMetadataPopulator::Populate() {
+  // Build model.
+  flatbuffers::FlatBufferBuilder model_fbb;
+  model_fbb.Finish(tflite::Model::Pack(model_fbb, &model_t_),
+                   tflite::ModelIdentifier());
+  return AppendAssociatedFiles(
+      reinterpret_cast<char*>(model_fbb.GetBufferPointer()),
+      model_fbb.GetSize());
+}
+
+}  // namespace metadata
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h index 9037f58..4410f84 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.h
@@ -31,6 +31,8 @@
 // Provides an interface to pack TFLite ModelMetadata [1] and corresponding
 // associated files into a TFLite FlatBuffer.
 //
+// This class is NOT thread-safe: the Load*() methods mutate internal state.
+//
 // [1]: https://www.tensorflow.org/lite/convert/metadata
 class ModelMetadataPopulator {
  public:
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/BUILD index ced8644..e2146b3 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/BUILD
@@ -6,9 +6,19 @@ ) cc_library( - name = "zip_mem_file", - srcs = ["zip_mem_file.cc"], - hdrs = ["zip_mem_file.h"], + name = "zip_writable_mem_file", + srcs = ["zip_writable_mem_file.cc"], + hdrs = ["zip_writable_mem_file.h"], + deps = [ + "@com_google_absl//absl/strings", + "@zlib//:zlib_minizip", + ], +) + +cc_library( + name = "zip_readonly_mem_file", + srcs = ["zip_readonly_mem_file.cc"], + hdrs = ["zip_readonly_mem_file.h"], deps = [ "@com_google_absl//absl/strings", "@zlib//:zlib_minizip",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc new file mode 100644 index 0000000..392b6b4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.cc
@@ -0,0 +1,138 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h"
+
+#include <algorithm>
+#include <cstdio>
+#include <cstring>
+
+#include "absl/strings/string_view.h"  // from @com_google_absl
+#include "contrib/minizip/ioapi.h"
+
+namespace tflite {
+namespace metadata {
+
+ZipReadOnlyMemFile::ZipReadOnlyMemFile(const char* buffer, size_t size)
+    : data_(buffer, size), offset_(0) {
+  zlib_filefunc64_def_.zopen64_file = OpenFile;
+  zlib_filefunc64_def_.zread_file = ReadFile;
+  zlib_filefunc64_def_.zwrite_file = WriteFile;
+  zlib_filefunc64_def_.ztell64_file = TellFile;
+  zlib_filefunc64_def_.zseek64_file = SeekFile;
+  zlib_filefunc64_def_.zclose_file = CloseFile;
+  zlib_filefunc64_def_.zerror_file = ErrorFile;
+  // The static callbacks recover this object from `opaque`, so this object
+  // must stay at the same address while minizip holds these callbacks.
+  zlib_filefunc64_def_.opaque = this;
+}
+
+zlib_filefunc64_def& ZipReadOnlyMemFile::GetFileFunc64Def() {
+  return zlib_filefunc64_def_;
+}
+
+/* static */
+voidpf ZipReadOnlyMemFile::OpenFile(voidpf opaque,
+                                    const void* filename,
+                                    int mode) {
+  // Result is never used, but needs to be non-null for `unzOpen2_64` not to
+  // fail.
+  return opaque;
+}
+
+/* static */
+uLong ZipReadOnlyMemFile::ReadFile(voidpf opaque,
+                                   voidpf stream,
+                                   void* buf,
+                                   uLong size) {
+  auto* mem_file = static_cast<ZipReadOnlyMemFile*>(opaque);
+  // `offset_` is unsigned, so only the upper bound needs checking. It can
+  // lie past the end, e.g. after an unchecked SEEK_SET.
+  if (mem_file->offset_ > mem_file->Size()) {
+    return 0;
+  }
+  // Clamp the read to the bytes remaining after the current offset.
+  const ZPOS64_T remaining = mem_file->Size() - mem_file->offset_;
+  if (size > remaining) {
+    size = static_cast<uLong>(remaining);
+  }
+  // Skip empty copies: `data_.data()` may be null for a null, empty buffer.
+  if (size > 0) {
+    std::memcpy(buf, mem_file->data_.data() + mem_file->offset_, size);
+  }
+  mem_file->offset_ += size;
+  return size;
+}
+
+/* static */
+uLong ZipReadOnlyMemFile::WriteFile(voidpf opaque,
+                                    voidpf stream,
+                                    const void* buf,
+                                    uLong size) {
+  // File is not writable: report that no bytes were written.
+  return 0;
+}
+
+/* static */
+ZPOS64_T ZipReadOnlyMemFile::TellFile(voidpf opaque, voidpf stream) {
+  return static_cast<ZipReadOnlyMemFile*>(opaque)->offset_;
+}
+
+/* static */
+long ZipReadOnlyMemFile::SeekFile  // NOLINT
+    (voidpf opaque, voidpf stream, ZPOS64_T offset, int origin) {
+  auto* mem_file = static_cast<ZipReadOnlyMemFile*>(opaque);
+  // `offset` and `offset_` are unsigned (ZPOS64_T): the arithmetic below
+  // wraps modulo 2^64, so only upper bounds are checked.
+  switch (origin) {
+    case SEEK_SET:
+      // Not bound-checked; reads past the end return 0 bytes.
+      mem_file->offset_ = offset;
+      return 0;
+    case SEEK_CUR: {
+      const ZPOS64_T new_offset = mem_file->offset_ + offset;
+      if (new_offset > mem_file->Size()) {
+        return -1;
+      }
+      mem_file->offset_ = new_offset;
+      return 0;
+    }
+    case SEEK_END:
+      // Offsets up to the file size are accepted (non-zero ones move past
+      // the end); larger ones are rejected.
+      if (offset > mem_file->Size()) {
+        return -1;
+      }
+      mem_file->offset_ = mem_file->Size() + offset;
+      return 0;
+    default:
+      return -1;
+  }
+}
+
+/* static */
+int ZipReadOnlyMemFile::CloseFile(voidpf opaque, voidpf stream) {
+  // Nothing to do: the buffer is not owned by this object.
+  return 0;
+}
+
+/* static */
+int ZipReadOnlyMemFile::ErrorFile(voidpf opaque, voidpf stream) {
+  // Unused.
+  return 0;
+}
+
+}  // namespace metadata
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h new file mode 100644 index 0000000..a1799ff --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h
@@ -0,0 +1,84 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_READONLY_MEM_FILE_H_
+#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_READONLY_MEM_FILE_H_
+
+#include <cstddef>
+#include <cstdlib>
+
+#include "absl/strings/string_view.h"  // from @com_google_absl
+#include "contrib/minizip/ioapi.h"
+
+namespace tflite {
+namespace metadata {
+
+// In-memory read-only zip file implementation.
+//
+// Adapted from [1], with a few key differences:
+// * backed by an `absl::string_view` instead of malloc-ed C buffers,
+// * supports opening the file for reading through `unzOpen2_64`.
+//
+// This class is NOT thread-safe. It is also not copyable (and therefore not
+// movable): the callbacks returned by `GetFileFunc64Def()` find this object
+// through their `opaque` pointer, which a copy would leave pointing at the
+// original object.
+//
+// [1]:
+// https://github.com/google/libkml/blob/master/third_party/zlib-1.2.3/contrib/minizip/iomem_simple.c
+class ZipReadOnlyMemFile {
+ public:
+  // Constructs an in-memory read-only zip file from a buffer. Does not copy or
+  // take ownership over the provided buffer: the caller is responsible for
+  // ensuring the buffer outlives this object.
+  ZipReadOnlyMemFile(const char* buffer, size_t size);
+
+  ZipReadOnlyMemFile(const ZipReadOnlyMemFile&) = delete;
+  ZipReadOnlyMemFile& operator=(const ZipReadOnlyMemFile&) = delete;
+
+  // Provides access to the `zlib_filefunc64_def` implementation for the
+  // in-memory zip file. The callbacks refer back to this object, so they must
+  // not be used after it is destroyed.
+  zlib_filefunc64_def& GetFileFunc64Def();
+
+ private:
+  // The string view backing the in-memory file.
+  absl::string_view data_;
+  // The current offset in the file.
+  ZPOS64_T offset_;
+  // The `zlib_filefunc64_def` implementation for this in-memory zip file.
+  zlib_filefunc64_def zlib_filefunc64_def_;
+
+  // Convenience function to access the current data size.
+  size_t Size() const { return data_.size(); }
+
+  // The file function implementations used in the `zlib_filefunc64_def`.
+  static voidpf OpenFile(voidpf opaque, const void* filename, int mode);
+  static uLong ReadFile(voidpf opaque, voidpf stream, void* buf, uLong size);
+  static uLong WriteFile(voidpf opaque,
+                         voidpf stream,
+                         const void* buf,
+                         uLong size);
+  static ZPOS64_T TellFile(voidpf opaque, voidpf stream);
+  static long SeekFile  // NOLINT
+      (voidpf opaque, voidpf stream, ZPOS64_T offset, int origin);
+  static int CloseFile(voidpf opaque, voidpf stream);
+  static int ErrorFile(voidpf opaque, voidpf stream);
+};
+
+}  // namespace metadata
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_READONLY_MEM_FILE_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc new file mode 100644 index 0000000..38ad17a --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.cc
@@ -0,0 +1,148 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h"
+
+#include <algorithm>
+#include <cstdio>
+#include <cstring>
+#include <string>
+
+#include "absl/strings/string_view.h"  // from @com_google_absl
+#include "contrib/minizip/ioapi.h"
+
+namespace tflite {
+namespace metadata {
+
+ZipWritableMemFile::ZipWritableMemFile(const char* buffer, size_t size)
+    : data_(buffer, size), offset_(0) {
+  zlib_filefunc64_def_.zopen64_file = OpenFile;
+  zlib_filefunc64_def_.zread_file = ReadFile;
+  zlib_filefunc64_def_.zwrite_file = WriteFile;
+  zlib_filefunc64_def_.ztell64_file = TellFile;
+  zlib_filefunc64_def_.zseek64_file = SeekFile;
+  zlib_filefunc64_def_.zclose_file = CloseFile;
+  zlib_filefunc64_def_.zerror_file = ErrorFile;
+  // The static callbacks recover this object from `opaque`, so this object
+  // must stay at the same address while minizip holds these callbacks.
+  zlib_filefunc64_def_.opaque = this;
+}
+
+zlib_filefunc64_def& ZipWritableMemFile::GetFileFunc64Def() {
+  return zlib_filefunc64_def_;
+}
+
+absl::string_view ZipWritableMemFile::GetFileContent() const {
+  return data_;
+}
+
+/* static */
+voidpf ZipWritableMemFile::OpenFile(voidpf opaque,
+                                    const void* filename,
+                                    int mode) {
+  // Result is never used, but needs to be non-null for `zipOpen2_64` not to
+  // fail.
+  return opaque;
+}
+
+/* static */
+uLong ZipWritableMemFile::ReadFile(voidpf opaque,
+                                   voidpf stream,
+                                   void* buf,
+                                   uLong size) {
+  auto* mem_file = static_cast<ZipWritableMemFile*>(opaque);
+  // `offset_` is unsigned, so only the upper bound needs checking. It can
+  // lie past the end, e.g. after an unchecked SEEK_SET.
+  if (mem_file->offset_ > mem_file->Size()) {
+    return 0;
+  }
+  // Clamp the read to the bytes remaining after the current offset.
+  const ZPOS64_T remaining = mem_file->Size() - mem_file->offset_;
+  if (size > remaining) {
+    size = static_cast<uLong>(remaining);
+  }
+  std::memcpy(buf, mem_file->data_.data() + mem_file->offset_, size);
+  mem_file->offset_ += size;
+  return size;
+}
+
+/* static */
+uLong ZipWritableMemFile::WriteFile(voidpf opaque,
+                                    voidpf stream,
+                                    const void* buf,
+                                    uLong size) {
+  auto* mem_file = static_cast<ZipWritableMemFile*>(opaque);
+  // Grow the backing string when the write extends past its end; `resize`
+  // zero-fills any gap left by an earlier seek past the end.
+  if (mem_file->offset_ + size > mem_file->Size()) {
+    mem_file->data_.resize(mem_file->offset_ + size);
+  }
+  mem_file->data_.replace(mem_file->offset_, size,
+                          static_cast<const char*>(buf), size);
+  mem_file->offset_ += size;
+  return size;
+}
+
+/* static */
+ZPOS64_T ZipWritableMemFile::TellFile(voidpf opaque, voidpf stream) {
+  return static_cast<ZipWritableMemFile*>(opaque)->offset_;
+}
+
+/* static */
+long ZipWritableMemFile::SeekFile  // NOLINT
+    (voidpf opaque, voidpf stream, ZPOS64_T offset, int origin) {
+  auto* mem_file = static_cast<ZipWritableMemFile*>(opaque);
+  // `offset` and `offset_` are unsigned (ZPOS64_T): the arithmetic below
+  // wraps modulo 2^64, so only upper bounds are checked.
+  switch (origin) {
+    case SEEK_SET:
+      // Not bound-checked; a later write past the end zero-fills the gap.
+      mem_file->offset_ = offset;
+      return 0;
+    case SEEK_CUR: {
+      const ZPOS64_T new_offset = mem_file->offset_ + offset;
+      if (new_offset > mem_file->Size()) {
+        return -1;
+      }
+      mem_file->offset_ = new_offset;
+      return 0;
+    }
+    case SEEK_END:
+      // Offsets up to the file size are accepted (non-zero ones move past
+      // the end); larger ones are rejected.
+      if (offset > mem_file->Size()) {
+        return -1;
+      }
+      mem_file->offset_ = mem_file->Size() + offset;
+      return 0;
+    default:
+      return -1;
+  }
+}
+
+/* static */
+int ZipWritableMemFile::CloseFile(voidpf opaque, voidpf stream) {
+  // Nothing to do.
+  return 0;
+}
+
+/* static */
+int ZipWritableMemFile::ErrorFile(voidpf opaque, voidpf stream) {
+  // Unused.
+  return 0;
+}
+
+}  // namespace metadata
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h new file mode 100644 index 0000000..30e42fd --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_writable_mem_file.h
@@ -0,0 +1,87 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_WRITABLE_MEM_FILE_H_
+#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_WRITABLE_MEM_FILE_H_
+
+#include <cstddef>
+#include <cstdlib>
+#include <string>
+
+#include "absl/strings/string_view.h"  // from @com_google_absl
+#include "contrib/minizip/ioapi.h"
+
+namespace tflite {
+namespace metadata {
+
+// In-memory zip file implementation.
+//
+// Adapted from [1], with a few key differences:
+// * backed by an `std::string` instead of malloc-ed C buffers,
+// * supports opening the file for writing through `zipOpen2_64`.
+//
+// This class is NOT thread-safe. It is also not copyable (and therefore not
+// movable): the callbacks returned by `GetFileFunc64Def()` find this object
+// through their `opaque` pointer, which a copy would leave pointing at the
+// original object.
+//
+// [1]:
+// https://github.com/google/libkml/blob/master/third_party/zlib-1.2.3/contrib/minizip/iomem_simple.c
+class ZipWritableMemFile {
+ public:
+  // Constructs an in-memory writable zip file from a buffer. The provided
+  // buffer is copied.
+  ZipWritableMemFile(const char* buffer, size_t size);
+
+  ZipWritableMemFile(const ZipWritableMemFile&) = delete;
+  ZipWritableMemFile& operator=(const ZipWritableMemFile&) = delete;
+
+  // Provides access to the `zlib_filefunc64_def` implementation for the
+  // in-memory zip file. The callbacks refer back to this object, so they must
+  // not be used after it is destroyed.
+  zlib_filefunc64_def& GetFileFunc64Def();
+  // Provides access to the file contents. The returned view is invalidated by
+  // any further write and by the destruction of this object.
+  absl::string_view GetFileContent() const;
+
+ private:
+  // The string backing the in-memory file.
+  std::string data_;
+  // The current offset in the file.
+  ZPOS64_T offset_;
+  // The `zlib_filefunc64_def` implementation for this in-memory zip file.
+  zlib_filefunc64_def zlib_filefunc64_def_;
+
+  // Convenience function to access the current data size.
+  size_t Size() const { return data_.size(); }
+
+  // The file function implementations used in the `zlib_filefunc64_def`.
+  static voidpf OpenFile(voidpf opaque, const void* filename, int mode);
+  static uLong ReadFile(voidpf opaque, voidpf stream, void* buf, uLong size);
+  static uLong WriteFile(voidpf opaque,
+                         voidpf stream,
+                         const void* buf,
+                         uLong size);
+  static ZPOS64_T TellFile(voidpf opaque, voidpf stream);
+  static long SeekFile  // NOLINT
+      (voidpf opaque, voidpf stream, ZPOS64_T offset, int origin);
+  static int CloseFile(voidpf opaque, voidpf stream);
+  static int ErrorFile(voidpf opaque, voidpf stream);
+};
+
+}  // namespace metadata
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_WRITABLE_MEM_FILE_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/BUILD index 5f000b4..681a443 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/java/BUILD
@@ -33,20 +33,27 @@ name = "tensorflowlite_support_metadata_lib", srcs = METADATA_SRCS, javacopts = ["--release 7"], - resource_jars = [ - "//tensorflow_lite_support/metadata:libmetadata_schema_java.jar", - "//tensorflow_lite_support/metadata:libschema_fbs_java.jar", - ], - # LINT.IfChange(dep) deps = [ "//tensorflow_lite_support/metadata:metadata_schema_java", "//tensorflow_lite_support/metadata:schema_fbs_java", "@org_checkerframework_qual", ], - # LINT.ThenChange(<INTERNAL>/release/build_metadata_pom.sh:dep) ) -alias( +# The target for OSS release, which includes the metadata Java library, the +# metadata schema Java binding, and the TFLite schema Java binding. +java_library( name = "tensorflow-lite-support-metadata-lib", - actual = ":tensorflowlite_support_metadata_lib", + srcs = METADATA_SRCS + [ + "//tensorflow_lite_support/metadata:metadata_schema_java_srcjar", + "//tensorflow_lite_support/metadata:schema_fbs_java_srcjar", + ], + javacopts = ["--release 7"], + # LINT.IfChange(dep) + deps = [ + # Need to be consistent as the deps used in flatbuffer_java_library. + "@flatbuffers//:runtime_java", + "@org_checkerframework_qual", + ], + # LINT.ThenChange(<INTERNAL>/release/build_metadata_pom.sh:dep) )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/BUILD index 91f7ad68..f771c48 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/BUILD
@@ -22,6 +22,7 @@ srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ + # build rule placeholder: numpy dep, "//tensorflow_lite_support/metadata:metadata_schema_py", "//tensorflow_lite_support/metadata:schema_py", "//tensorflow_lite_support/metadata/cc/python:_pywrap_metadata_version",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h index 0c49491..18797d8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/apis/GMLImage.h
@@ -19,8 +19,8 @@ NS_ASSUME_NONNULL_BEGIN /** Types of image sources. */ -typedef NSInteger GMLImageSourceType NS_TYPED_ENUM - NS_SWIFT_NAME(MLImageSourceType); +typedef NSInteger GMLImageSourceType + NS_TYPED_ENUM NS_SWIFT_NAME(MLImageSourceType); /** Image source is a `UIImage`. */ static const GMLImageSourceType GMLImageSourceTypeImage = 0; /** Image source is a `CVPixelBuffer`. */
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/third_party_licenses/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/third_party_licenses/BUILD index 4b75bc0..7c5f7e6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/third_party_licenses/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/odml/java/image/third_party_licenses/BUILD
@@ -62,6 +62,9 @@ "JsInterop Annotations": "third_party/java_src/jsinterop_annotations/LICENSE", "TensorFlow Lite Support": "third_party/tensorflow_lite_support/LICENSE", "Kotlin": "third_party/kotlin/kotlin/LICENSE", + "AndroidX collection jvm library": "third_party/java/androidx/collection/jvm/LICENSE", + "AndroidX core ktx library": "third_party/java/androidx/core/ktx/LICENSE", + "Android SDK": "third_party/java/android/android_sdk_linux/LICENSE", }, platform = "android", target = "//tensorflow_lite_support/odml/java/image",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/opensource/opensource_only.files b/third_party/tflite_support/src/tensorflow_lite_support/opensource/opensource_only.files index 0bce7f2..c98911c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/opensource/opensource_only.files +++ b/third_party/tflite_support/src/tensorflow_lite_support/opensource/opensource_only.files
@@ -19,7 +19,6 @@ tensorflow_lite_support/third_party/icu.BUILD: tensorflow_lite_support/third_party/leveldb.BUILD: tensorflow_lite_support/third_party/libyuv.BUILD: -tensorflow_lite_support/third_party/libzip.BUILD: tensorflow_lite_support/third_party/pybind11.BUILD: tensorflow_lite_support/third_party/python_runtime/BUILD: tensorflow_lite_support/third_party/six.BUILD: @@ -45,7 +44,10 @@ tensorflow_lite_support/tools/pip_package/metadata_writers.__init__.py: tensorflow_lite_support/tools/pip_package/setup.py: tensorflow_lite_support/tools/pip_package/simple_console_for_windows.py: +tensorflow_lite_support/tools/pip_package/task.__init__.py: +tensorflow_lite_support/tools/pip_package/task_audio.__init__.py: tensorflow_lite_support/tools/pip_package/task_core.__init__.py: tensorflow_lite_support/tools/pip_package/task_processor.__init__.py: +tensorflow_lite_support/tools/pip_package/task_text.__init__.py: tensorflow_lite_support/tools/pip_package/task_vision.__init__.py: tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py:
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/task/BUILD new file mode 100644 index 0000000..0988163 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/BUILD
@@ -0,0 +1,6 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:internal", + ], + licenses = ["notice"], # Apache 2.0 +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/audio_classifier.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/audio_classifier.py index fd008bd6..9a8543a 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/audio_classifier.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/audio_classifier.py
@@ -83,7 +83,7 @@ RuntimeError: If other types of error occurred. """ classifier = _CppAudioClassifier.create_from_options( - options.base_options, options.classification_options) + options.base_options, options.classification_options.to_pb2()) return cls(options, classifier) def create_input_tensor_audio(self) -> tensor_audio.TensorAudio: @@ -122,8 +122,10 @@ ValueError: If any of the input arguments is invalid. RuntimeError: If failed to run audio classification. """ - return self._classifier.classify( + classification_result = self._classifier.classify( _CppAudioBuffer(audio.buffer, audio.buffer_size, audio.format)) + return classifications_pb2.ClassificationResult.create_from_pb2( + classification_result) @property def required_input_buffer_size(self) -> int: @@ -132,5 +134,9 @@ @property def required_audio_format(self) -> _CppAudioFormat: - """Gets the required audio format for the model.""" + """Gets the required audio format for the model. + + Raises: + RuntimeError: If failed to get the required audio format. + """ return self._classifier.get_required_audio_format()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/audio_embedder.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/audio_embedder.py index 83479c1..5346c7b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/audio_embedder.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/audio_embedder.py
@@ -81,8 +81,8 @@ `AudioEmbedderOptions` such as missing the model. RuntimeError: If other types of error occurred. """ - embedder = _CppAudioEmbedder.create_from_options(options.base_options, - options.embedding_options) + embedder = _CppAudioEmbedder.create_from_options( + options.base_options, options.embedding_options.to_pb2()) return cls(options, embedder) def create_input_tensor_audio(self) -> tensor_audio.TensorAudio: @@ -119,13 +119,14 @@ ValueError: If any of the input arguments is invalid. RuntimeError: If failed to calculate the embedding vector. """ - return self._embedder.embed( + embedding_result = self._embedder.embed( _CppAudioBuffer(audio.buffer, audio.buffer_size, audio.format)) + return embedding_pb2.EmbeddingResult.create_from_pb2(embedding_result) def cosine_similarity(self, u: embedding_pb2.FeatureVector, v: embedding_pb2.FeatureVector) -> float: """Computes cosine similarity [1] between two feature vectors.""" - return self._embedder.cosine_similarity(u, v) + return self._embedder.cosine_similarity(u.to_pb2(), v.to_pb2()) def get_embedding_dimension(self, output_index: int) -> int: """Gets the dimensionality of the embedding output. @@ -151,5 +152,9 @@ @property def required_audio_format(self) -> _CppAudioFormat: - """Gets the required audio format for the model.""" + """Gets the required audio format for the model. + + Raises: + RuntimeError: If failed to get the required audio format. + """ return self._embedder.get_required_audio_format()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/BUILD index 2571d2e..4c4df440 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/BUILD
@@ -8,10 +8,6 @@ py_library( name = "audio_record", srcs = ["audio_record.py"], - deps = [ - # build rule placeholder: numpy dep, - # build rule placeholder: sounddevice dep, - ], ) py_library(
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/audio_record.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/audio_record.py index c9bb0e09..e3b2eb7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/audio_record.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/audio_record.py
@@ -14,7 +14,17 @@ """A module to record audio in a streaming basis.""" import threading import numpy as np -import sounddevice as sd + +try: +# pylint: disable=g-import-not-at-top + import sounddevice as sd +# pylint: enable=g-import-not-at-top +except OSError as oe: + sd = None + sd_error = oe +except ImportError as ie: + sd = None + sd_error = ie class AudioRecord(object): @@ -31,7 +41,12 @@ Raises: ValueError: if any of the arguments is non-positive. + ImportError: if failed to import `sounddevice`. + OSError: if failed to load `PortAudio`. """ + if sd is None: + raise sd_error + if channels <= 0: raise ValueError('channels must be postive.') if sampling_rate <= 0:
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/tensor_audio.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/tensor_audio.py index 1ff9ef63..3b3f543 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/tensor_audio.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/core/tensor_audio.py
@@ -18,14 +18,14 @@ from tensorflow_lite_support.python.task.audio.core import audio_record from tensorflow_lite_support.python.task.audio.core.pybinds import _pywrap_audio_buffer -_CppAudioFormat = _pywrap_audio_buffer.AudioFormat _LoadAudioBufferFromFile = _pywrap_audio_buffer.LoadAudioBufferFromFile +AudioFormat = _pywrap_audio_buffer.AudioFormat class TensorAudio(object): """A wrapper class to store the input audio.""" - def __init__(self, audio_format: _CppAudioFormat, buffer_size: int) -> None: + def __init__(self, audio_format: AudioFormat, buffer_size: int) -> None: """Initializes the `TensorAudio` object. Args: @@ -137,7 +137,7 @@ self._buffer[-shift:, :] = src[offset:offset + size].copy() @property - def format(self) -> _CppAudioFormat: + def format(self) -> AudioFormat: """Gets the audio format of the audio.""" return self._format
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/BUILD index be38598..ee61bbf 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/BUILD
@@ -1,4 +1,4 @@ -load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") +load("//tensorflow_lite_support/python/task:build_defs.bzl", "pybind_extension_may_pack_coral") package( default_visibility = [ @@ -7,7 +7,7 @@ licenses = ["notice"], # Apache 2.0 ) -pybind_extension( +pybind_extension_may_pack_coral( name = "_pywrap_audio_embedder", srcs = [ "_pywrap_audio_embedder.cc", @@ -24,7 +24,7 @@ ], ) -pybind_extension( +pybind_extension_may_pack_coral( name = "_pywrap_audio_classifier", srcs = [ "_pywrap_audio_classifier.cc", @@ -33,7 +33,9 @@ deps = [ "//tensorflow_lite_support/cc/task/audio:audio_classifier", "//tensorflow_lite_support/cc/task/audio/core:audio_buffer", + "//tensorflow_lite_support/cc/task/audio/proto:classifications_proto_inc", "//tensorflow_lite_support/cc/task/processor/proto:classification_options_cc_proto", + "//tensorflow_lite_support/cc/task/processor/proto:classifications_cc_proto", "//tensorflow_lite_support/python/task/core/pybinds:task_utils", "@pybind11", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc new file mode 100644 index 0000000..e2054cf --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_classifier.cc
@@ -0,0 +1,102 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdexcept>
+
+#include "pybind11/pybind11.h"
+#include "pybind11_protobuf/native_proto_caster.h"  // from @pybind11_protobuf
+#include "tensorflow_lite_support/cc/task/audio/audio_classifier.h"
+#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h"
+#include "tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h"
+#include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h"
+#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
+
+namespace tflite {
+namespace task {
+namespace audio {
+
+namespace {
+namespace py = ::pybind11;
+using PythonBaseOptions = ::tflite::python::task::core::BaseOptions;
+using CppBaseOptions = ::tflite::task::core::BaseOptions;
+}  // namespace
+
+PYBIND11_MODULE(_pywrap_audio_classifier, m) {
+  // python wrapper for C++ AudioClassifier class which shouldn't be directly
+  // used by the users.
+  pybind11_protobuf::ImportNativeProtoCasters();
+
+  py::class_<AudioClassifier>(m, "AudioClassifier")
+      .def_static(
+          "create_from_options",
+          [](const PythonBaseOptions& base_options,
+             const processor::ClassificationOptions& classification_options) {
+            AudioClassifierOptions options;
+            auto cpp_base_options =
+                core::convert_to_cpp_base_options(base_options);
+            options.set_allocated_base_options(cpp_base_options.release());
+
+            // Only copy optional fields that are explicitly set, so that
+            // unset fields keep the AudioClassifierOptions defaults.
+            if (classification_options.has_display_names_locale()) {
+              options.set_display_names_locale(
+                  classification_options.display_names_locale());
+            }
+            if (classification_options.has_max_results()) {
+              options.set_max_results(classification_options.max_results());
+            }
+            if (classification_options.has_score_threshold()) {
+              options.set_score_threshold(
+                  classification_options.score_threshold());
+            }
+            options.mutable_class_name_allowlist()->CopyFrom(
+                classification_options.class_name_allowlist());
+            options.mutable_class_name_denylist()->CopyFrom(
+                classification_options.class_name_denylist());
+
+            auto classifier = AudioClassifier::CreateFromOptions(options);
+            return core::get_value(classifier);
+          })
+      .def("classify",
+           [](AudioClassifier& self, const AudioBuffer& audio_buffer)
+               -> processor::ClassificationResult {
+             auto core_classification_result = self.Classify(audio_buffer);
+             // Convert from core::ClassificationResult to
+             // processor::ClassificationResult with a serialize/parse round
+             // trip, which relies on the two messages being wire-compatible.
+             // Fail loudly (surfaced as a Python RuntimeError) rather than
+             // returning a partially parsed result.
+             processor::ClassificationResult classification_result;
+             if (!classification_result.ParseFromString(
+                     core::get_value(core_classification_result)
+                         .SerializeAsString())) {
+               throw std::runtime_error(
+                   "Failed to convert the classification result.");
+             }
+             return classification_result;
+           })
+      .def("get_required_audio_format",
+           [](AudioClassifier& self) -> AudioBuffer::AudioFormat {
+             auto audio_format = self.GetRequiredAudioFormat();
+             return core::get_value(audio_format);
+           })
+      .def("get_required_input_buffer_size",
+           &AudioClassifier::GetRequiredInputBufferSize);
+}
+
+}  // namespace audio
+}  // namespace task
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc new file mode 100644 index 0000000..8b1d67d9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/audio/pybinds/_pywrap_audio_embedder.cc
@@ -0,0 +1,89 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "pybind11/pybind11.h"
+#include "pybind11_protobuf/native_proto_caster.h"  // from @pybind11_protobuf
+#include "tensorflow_lite_support/cc/task/audio/audio_embedder.h"
+#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h"
+#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h"
+#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
+
+namespace tflite {
+namespace task {
+namespace audio {
+
+namespace {
+namespace py = ::pybind11;
+using PythonBaseOptions = ::tflite::python::task::core::BaseOptions;
+using CppBaseOptions = ::tflite::task::core::BaseOptions;
+}  // namespace
+
+PYBIND11_MODULE(_pywrap_audio_embedder, m) {
+  // python wrapper for C++ AudioEmbedder class which shouldn't be directly used
+  // by the users.
+  // Registers pybind11_protobuf's casters so the protobuf arguments and results
+  // below (options, FeatureVector, EmbeddingResult) map to Python protos.
+  pybind11_protobuf::ImportNativeProtoCasters();
+
+  py::class_<AudioEmbedder>(m, "AudioEmbedder")
+      // Factory: converts the Python BaseOptions to the C++ proto and appends
+      // `embedding_options` as the single entry of the repeated
+      // `embedding_options` field. NOTE(review): core::get_value presumably
+      // raises a Python exception on a non-OK status; see task_utils.h.
+      .def_static(
+          "create_from_options",
+          [](const PythonBaseOptions& base_options,
+             const processor::EmbeddingOptions& embedding_options) {
+            AudioEmbedderOptions options;
+            auto cpp_base_options =
+                core::convert_to_cpp_base_options(base_options);
+
+            // set_allocated_* takes ownership of the released base options.
+            options.set_allocated_base_options(cpp_base_options.release());
+            options.add_embedding_options()->CopyFrom(embedding_options);
+            auto embedder = AudioEmbedder::CreateFromOptions(options);
+            return core::get_value(embedder);
+          })
+      // Static helper: cosine similarity of two FeatureVector protos.
+      .def_static("cosine_similarity",
+                  [](const processor::FeatureVector& u,
+                     const processor::FeatureVector& v) -> double {
+                    auto similarity = AudioEmbedder::CosineSimilarity(u, v);
+                    return core::get_value(similarity);
+                  })
+      // Embeds one AudioBuffer and returns the EmbeddingResult proto.
+      .def("embed",
+           [](AudioEmbedder& self,
+              const AudioBuffer& audio_buffer) -> processor::EmbeddingResult {
+             auto embedding_result = self.Embed(audio_buffer);
+             return core::get_value(embedding_result);
+           })
+      // Model/input properties; get_required_audio_format is unwrapped via
+      // core::get_value, the others are bound to the member functions directly.
+      .def("get_embedding_dimension", &AudioEmbedder::GetEmbeddingDimension)
+      .def("get_number_of_output_layers",
+           &AudioEmbedder::GetNumberOfOutputLayers)
+      .def("get_required_audio_format",
+           [](AudioEmbedder& self) -> AudioBuffer::AudioFormat {
+             auto audio_format = self.GetRequiredAudioFormat();
+             return core::get_value(audio_format);
+           })
+      .def("get_required_input_buffer_size",
+           &AudioEmbedder::GetRequiredInputBufferSize);
+}
+
+}  // namespace audio
+}  // namespace task
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/build_defs.bzl b/third_party/tflite_support/src/tensorflow_lite_support/python/task/build_defs.bzl new file mode 100644 index 0000000..86f2206 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/build_defs.bzl
@@ -0,0 +1,32 @@
+"""Build macro for pybind extensions that can optionally bundle Coral."""
+
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
+
+def pybind_extension_may_pack_coral(name, deps, **kwargs):
+    """Creates a `pybind_extension` target with an optional Coral dependency.
+
+    When the build is invoked with `--define darwinn_portable=1`, the Coral
+    EdgeTPU plugin is appended to the target's deps; otherwise the target
+    depends only on `deps`.
+
+    Args:
+      name: name of the generated pybind_extension target.
+      deps: dependencies always included in the generated target.
+      **kwargs: any further attributes, forwarded unchanged to
+          pybind_extension.
+    """
+
+    # `darwinn_portable` does double duty: it selects the plugin here and is
+    # also a flag the `edgetpu_coral_plugin` target itself needs to build.
+    coral_deps = select({
+        "//tensorflow_lite_support/examples/task:darwinn_portable": [
+            "//tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin",
+        ],
+        "//conditions:default": [],
+    })
+
+    pybind_extension(
+        name = name,
+        deps = deps + coral_deps,
+        **kwargs
+    )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/task/core/BUILD index 6e12483..36ddd7f4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/core/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/core/BUILD
@@ -1,4 +1,15 @@
+# Placeholder for internal Python strict library compatibility macro.
+
 package(
     default_visibility = ["//tensorflow_lite_support:internal"],
     licenses = ["notice"],  # Apache 2.0
 )
+
+# Optional imports shared by the Python Task library; currently TensorFlow's
+# doc_controls, which is replaced by a stand-in when TensorFlow is missing.
+py_library(
+    name = "optional_dependencies",
+    srcs = [
+        "optional_dependencies.py",
+    ],
+)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/core/optional_dependencies.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/core/optional_dependencies.py new file mode 100644 index 0000000..c002a661 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/core/optional_dependencies.py
@@ -0,0 +1,43 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TFLite Support's common but optional dependencies."""
+
+# TensorFlow isn't a dependency of tflite-support pip package. It's only
+# required in the API docgen pipeline so we'll ignore it if tensorflow is not
+# installed.
+# pylint: disable=g-import-not-at-top
+try:
+  # Import the submodule itself: `docs.doc_controls` is only set on the
+  # package once some code has imported that submodule.
+  from tensorflow.tools.docs import doc_controls
+except ImportError:
+  # ImportError (the base class of ModuleNotFoundError) also covers an
+  # installed TensorFlow whose docs tooling cannot be imported.
+
+  class _NoOpDocControls:
+    """Stand-in for `doc_controls`: every decorator returns its target as-is.
+
+    A `unittest.mock.MagicMock` must not be used here: decorating a method
+    with a mock attribute replaces the method by the mock's return value, so
+    e.g. `to_pb2()` would return a MagicMock instead of a protobuf message.
+    """
+
+    def __getattr__(self, name):
+      # Keep dunder lookups (copy, pickle, introspection) raising normally.
+      if name.startswith("__"):
+        raise AttributeError(name)
+      return lambda target: target
+
+  doc_controls = _NoOpDocControls()
+# pylint: enable=g-import-not-at-top
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/BUILD index adcd039..ab11e23 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/BUILD
@@ -1,4 +1,5 @@
 # Placeholder for internal Python strict library compatibility macro.
+load("//tensorflow_lite_support/cc/port:build_defs.bzl", "support_cc_proto_library", "support_py_proto_library")
 
 package(
     default_visibility = ["//tensorflow_lite_support:internal"],
@@ -9,7 +10,9 @@
     name = "embedding_pb2",
     srcs = ["embedding_pb2.py"],
     deps = [
+        # build rule placeholder: numpy dep,
         "//tensorflow_lite_support/cc/task/processor/proto:embedding_py_pb2",
+        "//tensorflow_lite_support/python/task/core:optional_dependencies",
     ],
 )
 
@@ -18,6 +21,7 @@
     srcs = ["embedding_options_pb2.py"],
     deps = [
         "//tensorflow_lite_support/cc/task/processor/proto:embedding_options_py_pb2",
+        "//tensorflow_lite_support/python/task/core:optional_dependencies",
     ],
 )
 
@@ -26,6 +30,7 @@
     srcs = ["bounding_box_pb2.py"],
     deps = [
         "//tensorflow_lite_support/cc/task/processor/proto:bounding_box_py_pb2",
+        "//tensorflow_lite_support/python/task/core:optional_dependencies",
     ],
 )
 
@@ -34,6 +39,7 @@
     srcs = ["class_pb2.py"],
     deps = [
         "//tensorflow_lite_support/cc/task/processor/proto:class_py_pb2",
+        "//tensorflow_lite_support/python/task/core:optional_dependencies",
     ],
 )
 
@@ -41,7 +47,9 @@
     name = "classifications_pb2",
     srcs = ["classifications_pb2.py"],
     deps = [
+        ":class_pb2",
         "//tensorflow_lite_support/cc/task/processor/proto:classifications_py_pb2",
+        "//tensorflow_lite_support/python/task/core:optional_dependencies",
     ],
 )
 
@@ -50,6 +58,7 @@
     srcs = ["classification_options_pb2.py"],
     deps = [
         "//tensorflow_lite_support/cc/task/processor/proto:classification_options_py_pb2",
+        "//tensorflow_lite_support/python/task/core:optional_dependencies",
     ],
 )
 
@@ -57,7 +66,10 @@
     name = "detections_pb2",
     srcs = ["detections_pb2.py"],
     deps = [
+        ":bounding_box_pb2",
+        ":class_pb2",
         "//tensorflow_lite_support/cc/task/processor/proto:detections_py_pb2",
+        "//tensorflow_lite_support/python/task/core:optional_dependencies",
     ],
 )
 
@@ -66,6 +78,7 @@
     srcs = ["detection_options_pb2.py"],
     deps = [
         "//tensorflow_lite_support/cc/task/processor/proto:detection_options_py_pb2",
+        "//tensorflow_lite_support/python/task/core:optional_dependencies",
     ],
 )
 
@@ -74,6 +87,8 @@
     srcs = ["segmentations_pb2.py"],
     deps = [
         "//tensorflow_lite_support/cc/task/vision/proto:segmentations_py_pb2",
+        # build rule placeholder: numpy dep,
+        "//tensorflow_lite_support/python/task/core:optional_dependencies",
     ],
 )
 
@@ -82,5 +97,36 @@
     srcs = ["segmentation_options_pb2.py"],
     deps = [
         "//tensorflow_lite_support/cc/task/processor/proto:segmentation_options_py_pb2",
+        "//tensorflow_lite_support/python/task/core:optional_dependencies",
+    ],
+)
+
+# Search options: the proto definition plus its C++ and Python bindings.
+proto_library(
+    name = "search_options_proto",
+    srcs = ["search_options.proto"],
+)
+
+support_cc_proto_library(
+    name = "search_options_cc_proto",
+    deps = [
+        ":search_options_proto",
+    ],
+)
+
+support_py_proto_library(
+    name = "search_options_py_pb2",
+    srcs = ["search_options.proto"],
+    api_version = 2,
+    proto_deps = [":search_options_proto"],
+)
+
+# NOTE(review): unlike the other *_pb2 wrappers here, this target does not
+# depend on the core `optional_dependencies` library; confirm it needs none.
+py_library(
+    name = "search_result_pb2",
+    srcs = ["search_result_pb2.py"],
+    deps = [
+        "//tensorflow_lite_support/cc/task/processor/proto:search_result_py_pb2",
     ],
 )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/bounding_box_pb2.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/bounding_box_pb2.py index e1d0485..969a7a6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/bounding_box_pb2.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/bounding_box_pb2.py
@@ -13,6 +13,62 @@
 # limitations under the License.
 """Bounding box protobuf."""
 
-from tensorflow_lite_support.cc.task.processor.proto import bounding_box_pb2
+import dataclasses
+from typing import Any
 
-BoundingBox = bounding_box_pb2.BoundingBox
+# Generated proto module; the dataclass below wraps its `BoundingBox` message.
+from tensorflow_lite_support.cc.task.processor.proto import bounding_box_pb2
+from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls
+
+
+@dataclasses.dataclass
+class BoundingBox:
+  """An integer bounding box, axis aligned.
+
+  Attributes:
+    origin_x: The X coordinate of the top-left corner, in pixels.
+    origin_y: The Y coordinate of the top-left corner, in pixels.
+    width: The width of the bounding box, in pixels.
+    height: The height of the bounding box, in pixels.
+  """
+
+  origin_x: int
+  origin_y: int
+  width: int
+  height: int
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> bounding_box_pb2.BoundingBox:
+    """Generates a protobuf object to pass to the C++ layer."""
+    return bounding_box_pb2.BoundingBox(
+        origin_x=self.origin_x,
+        origin_y=self.origin_y,
+        width=self.width,
+        height=self.height,
+    )
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(cls,
+                      pb2_obj: bounding_box_pb2.BoundingBox) -> "BoundingBox":
+    """Creates a `BoundingBox` object from the given protobuf object."""
+    return BoundingBox(
+        origin_x=pb2_obj.origin_x,
+        origin_y=pb2_obj.origin_y,
+        width=pb2_obj.width,
+        height=pb2_obj.height)
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, BoundingBox):
+      return False
+
+    # Compare the proto forms, i.e. the four fields copied by to_pb2().
+    return self.to_pb2().__eq__(other.to_pb2())
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/class_pb2.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/class_pb2.py index 1df20f7..2492b30 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/class_pb2.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/class_pb2.py
@@ -13,6 +13,68 @@
 # limitations under the License.
 """Class protobuf."""
 
-from tensorflow_lite_support.cc.task.processor.proto import class_pb2
+import dataclasses
+from typing import Any
 
-Category = class_pb2.Class
+from tensorflow_lite_support.cc.task.processor.proto import class_pb2
+from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls
+
+# Alias for the generated `Class` message, wrapped below as `Category`.
+_ClassProto = class_pb2.Class
+
+
+@dataclasses.dataclass
+class Category:
+  """A classification category.
+
+  Category is a util class, contains a label, its display name, a float
+  value as score, and the index of the label in the corresponding label file.
+  Typically it's used as the result of classification tasks.
+
+  Attributes:
+    index: The index of the label in the corresponding label file.
+    score: The probability score of this label category.
+    display_name: The display name of the label, which may be translated for
+      different locales. For example, a label, "apple", may be translated into
+      Spanish for display purpose, so that the `display_name` is "manzana".
+    category_name: The label of this category object.
+  """
+
+  index: int
+  score: float
+  display_name: str
+  category_name: str
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _ClassProto:
+    """Generates a protobuf object to pass to the C++ layer."""
+    return _ClassProto(
+        index=self.index,
+        score=self.score,
+        display_name=self.display_name,
+        # The proto names this field `class_name`; Python exposes category_name.
+        class_name=self.category_name)
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(cls, pb2_obj: _ClassProto) -> "Category":
+    """Creates a `Category` object from the given protobuf object."""
+    return Category(
+        index=pb2_obj.index,
+        score=pb2_obj.score,
+        display_name=pb2_obj.display_name,
+        category_name=pb2_obj.class_name)
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, Category):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/classification_options_pb2.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/classification_options_pb2.py index f346d405..ea9ba35 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/classification_options_pb2.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/classification_options_pb2.py
@@ -13,6 +13,88 @@
 # limitations under the License.
 """Classification options protobuf."""
 
-from tensorflow_lite_support.cc.task.processor.proto import classification_options_pb2
+import dataclasses
+from typing import Any, List, Optional
 
-ClassificationOptions = classification_options_pb2.ClassificationOptions
+from tensorflow_lite_support.cc.task.processor.proto import classification_options_pb2
+from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls
+
+_ClassificationOptionsProto = classification_options_pb2.ClassificationOptions
+
+
+@dataclasses.dataclass
+class ClassificationOptions:
+  """Options for classification processor.
+
+  Attributes:
+    display_names_locale: The locale to use for display names specified through
+      the TFLite Model Metadata.
+    max_results: The maximum number of top-scored classification results to
+      return.
+    score_threshold: Overrides the ones provided in the model metadata. Results
+      below this value are rejected.
+    category_name_allowlist: If non-empty, classifications whose class name is
+      not in this set will be filtered out. Duplicate or unknown class names are
+      ignored. Mutually exclusive with `category_name_denylist`.
+    category_name_denylist: If non-empty, classifications whose class name is in
+      this set will be filtered out. Duplicate or unknown class names are
+      ignored. Mutually exclusive with `category_name_allowlist`.
+  """
+
+  score_threshold: Optional[float] = None
+  category_name_allowlist: Optional[List[str]] = None
+  category_name_denylist: Optional[List[str]] = None
+  display_names_locale: Optional[str] = None
+  max_results: Optional[int] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _ClassificationOptionsProto:
+    """Generates a protobuf object to pass to the C++ layer."""
+    return _ClassificationOptionsProto(
+        score_threshold=self.score_threshold,
+        class_name_allowlist=self.category_name_allowlist,
+        class_name_denylist=self.category_name_denylist,
+        display_names_locale=self.display_names_locale,
+        max_results=self.max_results)
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(
+      cls, pb2_obj: _ClassificationOptionsProto) -> "ClassificationOptions":
+    """Creates a `ClassificationOptions` object from the given protobuf object.
+
+    Optional scalar fields that are not set on `pb2_obj` map to `None` instead
+    of the proto's default value, so that a round trip through `to_pb2()` does
+    not turn an unset field (e.g. `score_threshold`) into an explicitly set
+    one.
+    """
+
+    def _value_if_set(field_name: str) -> Any:
+      # Only copy scalars that are explicitly present on the proto.
+      return (getattr(pb2_obj, field_name)
+              if pb2_obj.HasField(field_name) else None)
+
+    return ClassificationOptions(
+        score_threshold=_value_if_set("score_threshold"),
+        category_name_allowlist=[
+            str(name) for name in pb2_obj.class_name_allowlist
+        ],
+        category_name_denylist=[
+            str(name) for name in pb2_obj.class_name_denylist
+        ],
+        display_names_locale=_value_if_set("display_names_locale"),
+        max_results=_value_if_set("max_results"))
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, ClassificationOptions):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/classifications_pb2.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/classifications_pb2.py index 9dbf15b..40f2d51 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/classifications_pb2.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/classifications_pb2.py
@@ -13,7 +13,108 @@
 # limitations under the License.
 """Classifications protobuf."""
 
-from tensorflow_lite_support.cc.task.processor.proto import classifications_pb2
+import dataclasses
+from typing import Any, List
 
-Classifications = classifications_pb2.Classifications
-ClassificationResult = classifications_pb2.ClassificationResult
+from tensorflow_lite_support.cc.task.processor.proto import classifications_pb2
+from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls
+from tensorflow_lite_support.python.task.processor.proto import class_pb2
+
+_ClassificationsProto = classifications_pb2.Classifications
+_ClassificationResultProto = classifications_pb2.ClassificationResult
+
+
+@dataclasses.dataclass
+class Classifications:
+  """List of predicted classes (aka labels) for a given classifier head.
+
+  Attributes:
+    categories: The array of predicted categories, usually sorted by descending
+      scores (e.g. from high to low probability).
+    head_index: The index of the classifier head these categories refer to. This
+      is useful for multi-head models.
+    head_name: The name of the classifier head, which is the corresponding
+      tensor metadata.
+  """
+
+  categories: List[class_pb2.Category]
+  head_index: int
+  head_name: str
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _ClassificationsProto:
+    """Generates a protobuf object to pass to the C++ layer."""
+    # The proto's repeated field holding the categories is named `classes`.
+    return _ClassificationsProto(
+        classes=[category.to_pb2() for category in self.categories],
+        head_index=self.head_index,
+        head_name=self.head_name)
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> "Classifications":
+    """Creates a `Classifications` object from the given protobuf object."""
+    # Each nested `Class` proto is rebuilt as a `class_pb2.Category` dataclass.
+    return Classifications(
+        categories=[
+            class_pb2.Category.create_from_pb2(category)
+            for category in pb2_obj.classes
+        ],
+        head_index=pb2_obj.head_index,
+        head_name=pb2_obj.head_name)
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, Classifications):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
+
+
+@dataclasses.dataclass
+class ClassificationResult:
+  """Contains one set of results per classifier head.
+
+  Attributes:
+    classifications: A list of `Classifications` objects.
+  """
+
+  classifications: List[Classifications]
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _ClassificationResultProto:
+    """Generates a protobuf object to pass to the C++ layer."""
+    return _ClassificationResultProto(classifications=[
+        classification.to_pb2() for classification in self.classifications
+    ])
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(
+      cls, pb2_obj: _ClassificationResultProto) -> "ClassificationResult":
+    """Creates a `ClassificationResult` object from the given protobuf object."""
+    return ClassificationResult(classifications=[
+        Classifications.create_from_pb2(classification)
+        for classification in pb2_obj.classifications
+    ])
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, ClassificationResult):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/detection_options_pb2.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/detection_options_pb2.py index 3a98a48..dee2cd7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/detection_options_pb2.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/detection_options_pb2.py
@@ -13,6 +13,88 @@
 # limitations under the License.
 """Detection options protobuf."""
 
-from tensorflow_lite_support.cc.task.processor.proto import detection_options_pb2
+import dataclasses
+from typing import Any, List, Optional
 
-DetectionOptions = detection_options_pb2.DetectionOptions
+from tensorflow_lite_support.cc.task.processor.proto import detection_options_pb2
+from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls
+
+_DetectionOptionsProto = detection_options_pb2.DetectionOptions
+
+
+@dataclasses.dataclass
+class DetectionOptions:
+  """Options for object detection processor.
+
+  Attributes:
+    display_names_locale: The locale to use for display names specified through
+      the TFLite Model Metadata.
+    max_results: The maximum number of top-scored detection results to
+      return.
+    score_threshold: Overrides the ones provided in the model metadata. Results
+      below this value are rejected.
+    category_name_allowlist: If non-empty, classifications whose class name is
+      not in this set will be filtered out. Duplicate or unknown class names are
+      ignored. Mutually exclusive with `category_name_denylist`.
+    category_name_denylist: If non-empty, classifications whose class name is in
+      this set will be filtered out. Duplicate or unknown class names are
+      ignored. Mutually exclusive with `category_name_allowlist`.
+  """
+
+  score_threshold: Optional[float] = None
+  category_name_allowlist: Optional[List[str]] = None
+  category_name_denylist: Optional[List[str]] = None
+  display_names_locale: Optional[str] = None
+  max_results: Optional[int] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _DetectionOptionsProto:
+    """Generates a protobuf object to pass to the C++ layer."""
+    return _DetectionOptionsProto(
+        score_threshold=self.score_threshold,
+        class_name_allowlist=self.category_name_allowlist,
+        class_name_denylist=self.category_name_denylist,
+        display_names_locale=self.display_names_locale,
+        max_results=self.max_results)
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(cls,
+                      pb2_obj: _DetectionOptionsProto) -> "DetectionOptions":
+    """Creates a `DetectionOptions` object from the given protobuf object.
+
+    Optional scalar fields that are not set on `pb2_obj` map to `None` instead
+    of the proto's default value, so that a round trip through `to_pb2()` does
+    not turn an unset field (e.g. `score_threshold`) into an explicitly set
+    one.
+    """
+
+    def _value_if_set(field_name: str) -> Any:
+      # Only copy scalars that are explicitly present on the proto.
+      return (getattr(pb2_obj, field_name)
+              if pb2_obj.HasField(field_name) else None)
+
+    return DetectionOptions(
+        score_threshold=_value_if_set("score_threshold"),
+        category_name_allowlist=[
+            str(name) for name in pb2_obj.class_name_allowlist
+        ],
+        category_name_denylist=[
+            str(name) for name in pb2_obj.class_name_denylist
+        ],
+        display_names_locale=_value_if_set("display_names_locale"),
+        max_results=_value_if_set("max_results"))
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, DetectionOptions):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/detections_pb2.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/detections_pb2.py index 5ac9d568..50f8693 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/detections_pb2.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/detections_pb2.py
@@ -13,7 +13,100 @@
 # limitations under the License.
 """Detections protobuf."""
 
-from tensorflow_lite_support.cc.task.processor.proto import detections_pb2
+import dataclasses
+from typing import Any, List
 
-Detection = detections_pb2.Detection
-DetectionResult = detections_pb2.DetectionResult
+from tensorflow_lite_support.cc.task.processor.proto import detections_pb2
+from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls
+from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2
+from tensorflow_lite_support.python.task.processor.proto import class_pb2
+
+_DetectionProto = detections_pb2.Detection
+_DetectionResultProto = detections_pb2.DetectionResult
+
+
+@dataclasses.dataclass
+class Detection:
+  """Represents one detected object in the object detector's results.
+
+  Attributes:
+    bounding_box: A `bounding_box_pb2.BoundingBox` object.
+    categories: A list of `class_pb2.Category` objects.
+  """
+
+  bounding_box: bounding_box_pb2.BoundingBox
+  categories: List[class_pb2.Category]
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _DetectionProto:
+    """Generates a protobuf object to pass to the C++ layer."""
+    # Both fields hold Python dataclasses; each must be converted to its proto.
+    return _DetectionProto(
+        bounding_box=self.bounding_box.to_pb2(),
+        classes=[category.to_pb2() for category in self.categories])
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(cls, pb2_obj: _DetectionProto) -> "Detection":
+    """Creates a `Detection` object from the given protobuf object."""
+    # Wrap the nested protos in their Python dataclass counterparts.
+    return Detection(
+        bounding_box=bounding_box_pb2.BoundingBox.create_from_pb2(
+            pb2_obj.bounding_box),
+        categories=[
+            class_pb2.Category.create_from_pb2(category)
+            for category in pb2_obj.classes
+        ])
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, Detection):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
+
+
+@dataclasses.dataclass
+class DetectionResult:
+  """Represents the list of detected objects.
+
+  Attributes:
+    detections: A list of `Detection` objects.
+  """
+
+  detections: List[Detection]
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _DetectionResultProto:
+    """Generates a protobuf object to pass to the C++ layer."""
+    return _DetectionResultProto(
+        detections=[detection.to_pb2() for detection in self.detections])
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(cls, pb2_obj: _DetectionResultProto) -> "DetectionResult":
+    """Creates a `DetectionResult` object from the given protobuf object."""
+    return DetectionResult(detections=[
+        Detection.create_from_pb2(detection) for detection in pb2_obj.detections
+    ])
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, DetectionResult):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/embedding_options_pb2.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/embedding_options_pb2.py index 44f8e6cc..c3c472a 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/embedding_options_pb2.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/embedding_options_pb2.py
@@ -13,6 +13,68 @@
 # limitations under the License.
 """Embedding options protobuf."""
 
-from tensorflow_lite_support.cc.task.processor.proto import embedding_options_pb2
+import dataclasses
+from typing import Any, Optional
 
-EmbeddingOptions = embedding_options_pb2.EmbeddingOptions
+from tensorflow_lite_support.cc.task.processor.proto import embedding_options_pb2
+from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls
+
+_EmbeddingOptionsProto = embedding_options_pb2.EmbeddingOptions
+
+
+@dataclasses.dataclass
+class EmbeddingOptions:
+  """Options for embedding processor.
+
+  Attributes:
+    l2_normalize: Whether to normalize the returned feature vector with L2 norm.
+      Use this option only if the model does not already contain a native
+      L2_NORMALIZATION TF Lite Op. In most cases, this is already the case and
+      L2 norm is thus achieved through TF Lite inference.
+    quantize: Whether the returned embedding should be quantized to bytes via
+      scalar quantization. Embeddings are implicitly assumed to be unit-norm and
+      therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use
+      the l2_normalize option if this is not the case.
+  """
+
+  l2_normalize: Optional[bool] = None
+  quantize: Optional[bool] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _EmbeddingOptionsProto:
+    """Generates a protobuf object to pass to the C++ layer."""
+    return _EmbeddingOptionsProto(
+        l2_normalize=self.l2_normalize, quantize=self.quantize)
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(cls,
+                      pb2_obj: _EmbeddingOptionsProto) -> "EmbeddingOptions":
+    """Creates a `EmbeddingOptions` object from the given protobuf object.
+
+    Fields that are not set on `pb2_obj` map to `None` instead of the proto's
+    default value, so that a round trip through `to_pb2()` leaves them unset.
+    """
+
+    def _value_if_set(field_name: str) -> Any:
+      # Only copy fields that are explicitly present on the proto.
+      return (getattr(pb2_obj, field_name)
+              if pb2_obj.HasField(field_name) else None)
+
+    return EmbeddingOptions(
+        l2_normalize=_value_if_set("l2_normalize"),
+        quantize=_value_if_set("quantize"))
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, EmbeddingOptions):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/embedding_pb2.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/embedding_pb2.py index feb4887c..3ca5b14 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/embedding_pb2.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/embedding_pb2.py
@@ -13,8 +13,156 @@ # limitations under the License. """Embedding result protobuf.""" -from tensorflow_lite_support.cc.task.processor.proto import embedding_pb2 +import dataclasses +from typing import Any, List -FeatureVector = embedding_pb2.FeatureVector -Embedding = embedding_pb2.Embedding -EmbeddingResult = embedding_pb2.EmbeddingResult +import numpy as np +from tensorflow_lite_support.cc.task.processor.proto import embedding_pb2 +from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls + +_FeatureVectorProto = embedding_pb2.FeatureVector +_EmbeddingProto = embedding_pb2.Embedding +_EmbeddingResultProto = embedding_pb2.EmbeddingResult + + +@dataclasses.dataclass +class FeatureVector: + """A dense feature vector. + + Only one of the two fields is ever present. + Feature vectors are assumed to be one-dimensional and L2-normalized. + + Attributes: + value: A NumPy array indidcating the raw output of the embedding layer. The + datatype of elements in the array can be either float or uint8 if + `quantize` is set to True in `EmbeddingOptions`. + """ + + value: np.ndarray + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _FeatureVectorProto: + """Generates a protobuf object to pass to the C++ layer.""" + + if self.value.dtype == float: + return _FeatureVectorProto(value_float=self.value) + + elif self.value.dtype == np.uint8: + return _FeatureVectorProto(value_string=bytes(self.value)) + + else: + raise ValueError("Invalid dtype. 
Only float and np.uint8 are supported.") + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _FeatureVectorProto) -> "FeatureVector": + """Creates a `FeatureVector` object from the given protobuf object.""" + + if pb2_obj.value_float: + return FeatureVector( + value=np.array(pb2_obj.value_float, dtype=float)) + + elif pb2_obj.value_string: + return FeatureVector( + value=np.array(bytearray(pb2_obj.value_string), dtype=np.uint8)) + + else: + raise ValueError("Either value_float or value_string must exist.") + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, FeatureVector): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class Embedding: + """Result produced by one of the embedder model output layers. + + Attributes: + feature_vector: The output feature vector. + output_index: The index of the model output layer that produced this feature + vector. + """ + + feature_vector: FeatureVector + output_index: int + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _EmbeddingProto: + """Generates a protobuf object to pass to the C++ layer.""" + return _EmbeddingProto( + feature_vector=self.feature_vector.to_pb2(), + output_index=self.output_index) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _EmbeddingProto) -> "Embedding": + """Creates a `Embedding` object from the given protobuf object.""" + return Embedding( + feature_vector=FeatureVector.create_from_pb2(pb2_obj.feature_vector), + output_index=pb2_obj.output_index) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. 
+ """ + if not isinstance(other, Embedding): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class EmbeddingResult: + """Embeddings produced by the Embedder. + + Attributes: + embeddings: The embeddings produced by each of the model output layers. + Except in advanced cases, the embedding model has a single output layer, + and this list is thus made of a single element feature vector. + """ + + embeddings: List[Embedding] + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _EmbeddingResultProto: + """Generates a protobuf object to pass to the C++ layer.""" + return _EmbeddingResultProto( + embeddings=[embedding.to_pb2() for embedding in self.embeddings]) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _EmbeddingResultProto) -> "EmbeddingResult": + """Creates a `EmbeddingResult` object from the given protobuf object.""" + return EmbeddingResult(embeddings=[ + Embedding.create_from_pb2(embedding) for embedding in pb2_obj.embeddings + ]) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, EmbeddingResult): + return False + + return self.to_pb2().__eq__(other.to_pb2())
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/search_options.proto b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/search_options.proto new file mode 100644 index 0000000..618f42133 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/search_options.proto
@@ -0,0 +1,46 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.python.task.processor; + +option java_multiple_files = true; +option java_package = "org.tensorflow.lite.task.processor.proto"; + +// Options for Python search processor. +// See C++ search options at: +// https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/processor/proto/search_options.proto +// Next Id: 4 +message SearchOptions { + // The index file to search into. Mandatory only if the index is not attached + // to the output tensor metadata as an AssociatedFile with type + // SCANN_INDEX_FILE. + // The index file can be specified by one of the following two ways: + // + // (1) file contents loaded in `index_file_content`. + // (2) file path in `index_file_name`. + // + // If more than one field of these fields is provided, they are used in this + // precedence order. + // + // The path to the index file to open and mmap in memory. + optional string index_file_name = 1; + // The index file contents as a byte array. + optional bytes index_file_content = 2; + + // Maximum number of nearest neighbor results to return. + optional int32 max_results = 3 [default = 5]; +}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/search_result_pb2.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/search_result_pb2.py new file mode 100644 index 0000000..893770f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/search_result_pb2.py
@@ -0,0 +1,19 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Search result protobuf.""" + +from tensorflow_lite_support.cc.task.processor.proto import search_result_pb2 + +SearchResult = search_result_pb2.SearchResult +NearestNeighbor = search_result_pb2.NearestNeighbor
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/segmentation_options_pb2.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/segmentation_options_pb2.py index 4511944..d90320b 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/segmentation_options_pb2.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/segmentation_options_pb2.py
@@ -13,6 +13,62 @@ # limitations under the License. """Segmentation options protobuf.""" -from tensorflow_lite_support.cc.task.processor.proto import segmentation_options_pb2 +import dataclasses +import enum +from typing import Any, Optional -SegmentationOptions = segmentation_options_pb2.SegmentationOptions +from tensorflow_lite_support.cc.task.processor.proto import segmentation_options_pb2 +from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls + +_SegmentationOptionsProto = segmentation_options_pb2.SegmentationOptions + + +class OutputType(enum.Enum): + UNSPECIFIED = 0 + CATEGORY_MASK = 1 + CONFIDENCE_MASK = 2 + + +@dataclasses.dataclass +class SegmentationOptions: + """Options for segmentation processor. + + Attributes: + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + output_type: The output mask type allows specifying the type of + post-processing to perform on the raw model results. + """ + + display_names_locale: Optional[str] = None + output_type: Optional[OutputType] = OutputType.CATEGORY_MASK + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _SegmentationOptionsProto: + """Generates a protobuf object to pass to the C++ layer.""" + return _SegmentationOptionsProto( + display_names_locale=self.display_names_locale, + output_type=self.output_type.value) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _SegmentationOptionsProto) -> "SegmentationOptions": + """Creates a `SegmentationOptions` object from the given protobuf object.""" + return SegmentationOptions( + display_names_locale=pb2_obj.display_names_locale, + output_type=OutputType(pb2_obj.output_type)) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. 
+ """ + if not isinstance(other, SegmentationOptions): + return False + + return self.to_pb2().__eq__(other.to_pb2())
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/segmentations_pb2.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/segmentations_pb2.py index 5d99cb6..ea5b455e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/segmentations_pb2.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/segmentations_pb2.py
@@ -13,7 +13,252 @@ # limitations under the License. """Segmentations protobuf.""" +import dataclasses +from typing import Any, Tuple, List, Optional + +import numpy as np +from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls + +# Using the proto in vision.proto here instead of processor.proto to match with +# the C++ layer. It's to avoid converting the large confidence masks or category +# mask from the proto type defined in vision.proto to processor.proto. For other +# tasks, the proto in processor.proto is always used in the Python layer and +# vision.proto <-> processor.proto's proto conversion happens in the C++ layer +# as those conversions of small protobuf objects are trivial. from tensorflow_lite_support.cc.task.vision.proto import segmentations_pb2 -Segmentation = segmentations_pb2.Segmentation -SegmentationResult = segmentations_pb2.SegmentationResult +_SegmentationProto = segmentations_pb2.Segmentation +_ConfidenceMaskProto = segmentations_pb2.Segmentation.ConfidenceMask +_ColoredLabelProto = segmentations_pb2.Segmentation.ColoredLabel +_SegmentationResultProto = segmentations_pb2.SegmentationResult + + +@dataclasses.dataclass +class ConfidenceMask: + """2D-array representing the confidence mask in row major order. + + For each pixel, the value indicates the prediction confidence usually + in the [0, 1] range where higher values represent a stronger confidence. + Ultimately this is model specific, and other range of values might be used. + + Attributes: + value: A NumPy 2D-array indicating the prediction confidence values usually + in the range [0, 1]. 
+ """ + value: np.ndarray + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ConfidenceMaskProto: + """Generates a protobuf object to pass to the C++ layer.""" + return _ConfidenceMaskProto(value=self.value.flatten()) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _ConfidenceMaskProto, height: int, + width: int) -> "ConfidenceMask": + """Creates a `ConfidenceMask` object from the given protobuf and size.""" + return ConfidenceMask(value=np.array(pb2_obj.value).reshape(height, width)) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, ConfidenceMask): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class ColoredLabel: + """Defines a label associated with an RGB color, for display purposes. + + Attributes: + color: The RGB color components for the label, in the [0, 255] range. + category_name: The class name, as provided in the label map packed in the + TFLite ModelMetadata. + display_name: The display name, as provided in the label map (if available) + packed in the TFLite Model Metadata . 
+ """ + + color: Tuple[int, int, int] + category_name: str + display_name: str + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ColoredLabelProto: + """Generates a protobuf object to pass to the C++ layer.""" + r, g, b = self.color + return _ColoredLabelProto( + r=r, + g=g, + b=b, + class_name=self.category_name, + display_name=self.display_name) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _ColoredLabelProto) -> "ColoredLabel": + """Creates a `ColoredLabel` object from the given protobuf object.""" + return ColoredLabel( + color=(pb2_obj.r, pb2_obj.g, pb2_obj.b), + category_name=pb2_obj.class_name, + display_name=pb2_obj.display_name) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, ColoredLabel): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class Segmentation: + """Represents one Segmentation object in the image segmenter's results. + + Attributes: + height: The height of the mask. This is an intrinsic parameter of the model + being used, and does not depend on the input image dimensions. + width: The width of the mask. This is an intrinsic parameter of the model + being used, and does not depend on the input image dimensions. + colored_labels: A list of `ColoredLabel` objects. + category_mask: A NumPy 2D-array of the category mask. + confidence_masks: A list of `ConfidenceMask` objects. 
+ """ + + height: int + width: int + colored_labels: List[ColoredLabel] + category_mask: Optional[np.ndarray] = None + confidence_masks: Optional[List[ConfidenceMask]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _SegmentationProto: + """Generates a protobuf object to pass to the C++ layer.""" + + if self.category_mask is not None: + return _SegmentationProto( + height=self.height, + width=self.width, + category_mask=bytes(self.category_mask), + colored_labels=[ + colored_label.to_pb2() for colored_label in self.colored_labels + ]) + + elif self.confidence_masks is not None: + segmentation_proto = _SegmentationProto() + segmentation_proto.height = self.height + segmentation_proto.width = self.width + segmentation_proto.confidence_masks.confidence_mask.extend([ + confidence_mask.to_pb2() for confidence_mask in self.confidence_masks + ]) + segmentation_proto.colored_labels.extend( + [colored_label.to_pb2() for colored_label in self.colored_labels]) + return segmentation_proto + else: + raise ValueError("Either category_mask or confidence_masks must be set.") + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _SegmentationProto) -> "Segmentation": + """Creates a `Segmentation` object from the given protobuf object.""" + + if pb2_obj.category_mask: + return Segmentation( + height=pb2_obj.height, + width=pb2_obj.width, + category_mask=np.array(bytearray(pb2_obj.category_mask)).reshape( + pb2_obj.height, pb2_obj.width), + colored_labels=[ + ColoredLabel.create_from_pb2(colored_label) + for colored_label in pb2_obj.colored_labels + ]) + + elif pb2_obj.confidence_masks.confidence_mask: + confidence_masks = [ + ConfidenceMask.create_from_pb2(mask, pb2_obj.height, pb2_obj.width) + for mask in pb2_obj.confidence_masks.confidence_mask + ] + return Segmentation( + height=pb2_obj.height, + width=pb2_obj.width, + confidence_masks=confidence_masks, + colored_labels=[ + ColoredLabel.create_from_pb2(colored_label) + for 
colored_label in pb2_obj.colored_labels + ]) + else: + raise ValueError("Either category_mask or confidence_masks must be set.") + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, Segmentation): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class SegmentationResult: + """Results of performing image segmentation. + + Note that at the time, a single `Segmentation` element is expected to be + returned; the field is made repeated for later extension to e.g. instance + segmentation models, which may return one segmentation per object. + + Attributes: + segmentations: A list of `Segmentation` objects. + """ + + segmentations: List[Segmentation] + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _SegmentationResultProto: + """Generates a protobuf object to pass to the C++ layer.""" + return _SegmentationResultProto(segmentation=[ + segmentation.to_pb2() for segmentation in self.segmentations + ]) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _SegmentationResultProto) -> "SegmentationResult": + """Creates a `SegmentationResult` object from the given protobuf object.""" + return SegmentationResult(segmentations=[ + Segmentation.create_from_pb2(segmentation) + for segmentation in pb2_obj.segmentation + ]) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, SegmentationResult): + return False + + return self.to_pb2().__eq__(other.to_pb2())
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/BUILD index 1a4f086..447b03c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/BUILD
@@ -18,3 +18,18 @@ "//tensorflow_lite_support/python/task/text/pybinds:_pywrap_text_embedder", ], ) + +py_library( + name = "text_searcher", + srcs = [ + "text_searcher.py", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", + "//tensorflow_lite_support/python/task/processor/proto:embedding_options_pb2", + "//tensorflow_lite_support/python/task/processor/proto:search_options_py_pb2", + "//tensorflow_lite_support/python/task/processor/proto:search_result_pb2", + "//tensorflow_lite_support/python/task/text/pybinds:_pywrap_text_searcher", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/BUILD index 6f5a88c4..0d1f4a4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/BUILD
@@ -1,4 +1,4 @@ -load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") +load("//tensorflow_lite_support/python/task:build_defs.bzl", "pybind_extension_may_pack_coral") package( default_visibility = [ @@ -7,7 +7,7 @@ licenses = ["notice"], # Apache 2.0 ) -pybind_extension( +pybind_extension_may_pack_coral( name = "_pywrap_text_embedder", srcs = [ "_pywrap_text_embedder.cc", @@ -15,9 +15,26 @@ module_name = "_pywrap_text_embedder", deps = [ "//tensorflow_lite_support/cc/task/text:text_embedder", + "//tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_op_resolver", "//tensorflow_lite_support/python/task/core/pybinds:task_utils", "@pybind11", "@pybind11_abseil//pybind11_abseil:status_casters", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ], ) + +pybind_extension_may_pack_coral( + name = "_pywrap_text_searcher", + srcs = [ + "_pywrap_text_searcher.cc", + ], + module_name = "_pywrap_text_searcher", + deps = [ + "//tensorflow_lite_support/cc/task/text:text_searcher", + "//tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_op_resolver", + "//tensorflow_lite_support/python/task/core/pybinds:task_utils", + "//tensorflow_lite_support/python/task/processor/proto:search_options_cc_proto", + "@pybind11", + "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/_pywrap_text_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/_pywrap_text_embedder.cc new file mode 100644 index 0000000..ac100f0 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/_pywrap_text_embedder.cc
@@ -0,0 +1,70 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "pybind11/pybind11.h" +#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf +#include "tensorflow_lite_support/cc/task/text/text_embedder.h" +#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h" +#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" + +namespace tflite { +namespace task { +namespace text { + +namespace { +using PythonBaseOptions = ::tflite::python::task::core::BaseOptions; +using CppBaseOptions = ::tflite::task::core::BaseOptions; +} // namespace + +PYBIND11_MODULE(_pywrap_text_embedder, m) { + // python wrapper for C++ TextEmbeder class which shouldn't be directly used + // by the users. 
+ pybind11_protobuf::ImportNativeProtoCasters(); + + pybind11::class_<TextEmbedder>(m, "TextEmbedder") + .def_static( + "create_from_options", + [](const PythonBaseOptions& base_options, + const processor::EmbeddingOptions& embedding_options) { + TextEmbedderOptions options; + auto cpp_base_options = + core::convert_to_cpp_base_options(base_options); + + options.set_allocated_base_options(cpp_base_options.release()); + options.add_embedding_options()->CopyFrom(embedding_options); + auto embedder = TextEmbedder::CreateFromOptions( + options, CreateQACustomOpResolver()); + return core::get_value(embedder); + }) + .def("embed", + [](TextEmbedder& self, + const std::string& text) -> processor::EmbeddingResult { + auto embedding_result = self.Embed(text); + return core::get_value(embedding_result); + }) + .def("get_embedding_dimension", &TextEmbedder::GetEmbeddingDimension) + .def("get_number_of_output_layers", + &TextEmbedder::GetNumberOfOutputLayers) + .def_static("cosine_similarity", + [](const processor::FeatureVector& u, + const processor::FeatureVector& v) -> double { + auto similarity = TextEmbedder::CosineSimilarity(u, v); + return core::get_value(similarity); + }); +} + +} // namespace text +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/_pywrap_text_searcher.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/_pywrap_text_searcher.cc new file mode 100644 index 0000000..8fc3dd9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/pybinds/_pywrap_text_searcher.cc
@@ -0,0 +1,90 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "pybind11/pybind11.h" +#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf +#include "tensorflow_lite_support/cc/task/text/text_searcher.h" +#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h" +#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" +#include "tensorflow_lite_support/python/task/processor/proto/search_options.pb.h" + +namespace tflite { +namespace task { +namespace text { + +namespace { +namespace py = ::pybind11; +using PythonBaseOptions = ::tflite::python::task::core::BaseOptions; +using PythonSearchOptions = ::tflite::python::task::processor::SearchOptions; +using CppBaseOptions = ::tflite::task::core::BaseOptions; +using CppEmbeddingOptions = ::tflite::task::processor::EmbeddingOptions; +using CppSearchOptions = ::tflite::task::processor::SearchOptions; +} // namespace + +PYBIND11_MODULE(_pywrap_text_searcher, m) { + // python wrapper for C++ TextSearcher class which shouldn't be directly used + // by the users. 
+ pybind11_protobuf::ImportNativeProtoCasters(); + + pybind11::class_<TextSearcher>(m, "TextSearcher") + .def_static( + "create_from_options", + [](const PythonBaseOptions& base_options, + const processor::EmbeddingOptions& embedding_options, + const PythonSearchOptions& search_options) { + TextSearcherOptions options; + auto cpp_base_options = + core::convert_to_cpp_base_options(base_options); + options.set_allocated_base_options(cpp_base_options.release()); + + std::unique_ptr<CppEmbeddingOptions> cpp_embedding_options = + std::make_unique<CppEmbeddingOptions>(); + cpp_embedding_options->CopyFrom(embedding_options); + options.set_allocated_embedding_options( + cpp_embedding_options.release()); + + std::unique_ptr<CppSearchOptions> cpp_search_options = + std::make_unique<CppSearchOptions>(); + if (search_options.has_index_file_content()) { + cpp_search_options->mutable_index_file()->set_file_content( + search_options.index_file_content()); + } + if (search_options.has_index_file_name()) { + cpp_search_options->mutable_index_file()->set_file_name( + search_options.index_file_name()); + } + if (search_options.has_max_results()) { + cpp_search_options->set_max_results(search_options.max_results()); + } + + options.set_allocated_search_options(cpp_search_options.release()); + auto searcher = TextSearcher::CreateFromOptions( + options, CreateQACustomOpResolver()); + return core::get_value(searcher); + }) + .def("search", + [](TextSearcher& self, + const std::string& text) -> processor::SearchResult { + auto search_result = self.Search(text); + return core::get_value(search_result); + }) + .def("get_user_info", [](TextSearcher& self) -> py::str { + return py::str(self.GetUserInfo()->data()); + }); +} + +} // namespace text +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/text_embedder.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/text_embedder.py index d4e5b60..322c118c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/text_embedder.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/text_embedder.py
@@ -74,8 +74,8 @@ `TextEmbedderOptions` such as missing the model. RuntimeError: If other types of error occurred. """ - embedder = _CppTextEmbedder.create_from_options(options.base_options, - options.embedding_options) + embedder = _CppTextEmbedder.create_from_options( + options.base_options, options.embedding_options.to_pb2()) return cls(options, embedder) def embed(self, text: str) -> embedding_pb2.EmbeddingResult: @@ -91,12 +91,13 @@ ValueError: If any of the input arguments is invalid. RuntimeError: If failed to calculate the embedding vector. """ - return self._embedder.embed(text) + embedding_result = self._embedder.embed(text) + return embedding_pb2.EmbeddingResult.create_from_pb2(embedding_result) def cosine_similarity(self, u: embedding_pb2.FeatureVector, v: embedding_pb2.FeatureVector) -> float: """Computes cosine similarity [1] between two feature vectors.""" - return self._embedder.cosine_similarity(u, v) + return self._embedder.cosine_similarity(u.to_pb2(), v.to_pb2()) def get_embedding_dimension(self, output_index: int) -> int: """Gets the dimensionality of the embedding output.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/text_searcher.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/text_searcher.py new file mode 100644 index 0000000..198c4aa1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/text/text_searcher.py
@@ -0,0 +1,126 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Text searcher task.""" + +import dataclasses +from typing import Optional + +from tensorflow_lite_support.python.task.core.proto import base_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import search_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import search_result_pb2 +from tensorflow_lite_support.python.task.text.pybinds import _pywrap_text_searcher + +_CppTextSearcher = _pywrap_text_searcher.TextSearcher +_BaseOptions = base_options_pb2.BaseOptions +_EmbeddingOptions = embedding_options_pb2.EmbeddingOptions +_SearchOptions = search_options_pb2.SearchOptions + + +@dataclasses.dataclass +class TextSearcherOptions: + """Options for the text search task.""" + base_options: _BaseOptions + embedding_options: _EmbeddingOptions = _EmbeddingOptions() + search_options: _SearchOptions = _SearchOptions() + + +class TextSearcher(object): + """Class to performs text search. + + It works by performing embedding extraction on text, followed by + nearest-neighbor search in an index of embeddings through ScaNN. + """ + + def __init__(self, options: TextSearcherOptions, + cpp_searcher: _CppTextSearcher) -> None: + """Initializes the `TextSearcher` object.""" + # Creates the object of C++ TextSearcher class. 
+ self._options = options + self._searcher = cpp_searcher + + @classmethod + def create_from_file(cls, + model_file_path: str, + index_file_path: Optional[str] = None) -> "TextSearcher": + """Creates the `TextSearcher` object from a TensorFlow Lite model. + + Args: + model_file_path: Path to the model. + index_file_path: Path to the index. Only required if the index is not + attached to the output tensor metadata as an AssociatedFile with type + SCANN_INDEX_FILE. + + Returns: + `TextSearcher` object that's created from `options`. + + Raises: + ValueError: If failed to create `TextSearcher` object from the provided + file such as invalid file. + RuntimeError: If other types of error occurred. + """ + options = TextSearcherOptions( + base_options=_BaseOptions(file_name=model_file_path), + search_options=_SearchOptions(index_file_name=index_file_path)) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, options: TextSearcherOptions) -> "TextSearcher": + """Creates the `TextSearcher` object from text searcher options. + + Args: + options: Options for the text searcher task. + + Returns: + `TextSearcher` object that's created from `options`. + Raises: + ValueError: If failed to create `TextSearcher` object from + `TextSearcherOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + searcher = _CppTextSearcher.create_from_options( + options.base_options, options.embedding_options.to_pb2(), + options.search_options) + return cls(options, searcher) + + def search(self, text: str) -> search_result_pb2.SearchResult: + """Search for text with similar semantic meaning. + + This method performs actual feature extraction on the provided text input, + followed by nearest-neighbor search in the index. + + Args: + text: the input text, used to extract the feature vectors. + + Returns: + search result. + + Raises: + ValueError: If any of the input arguments is invalid. 
+ RuntimeError: If failed to perform nearest-neighbor search. + """ + return self._searcher.search(text) + + def get_user_info(self) -> str: + """Gets the user info stored in the index file. + + Returns: + Opaque user info stored in the index file (if any), in raw binary form. + Returns an empty string if the index doesn't contain user info. + """ + return self._searcher.get_user_info() + + @property + def options(self) -> TextSearcherOptions: + return self._options
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/BUILD index af278d8..664facd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/BUILD
@@ -44,7 +44,6 @@ ], deps = [ "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", - "//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2", "//tensorflow_lite_support/python/task/processor/proto:segmentation_options_pb2", "//tensorflow_lite_support/python/task/processor/proto:segmentations_pb2", "//tensorflow_lite_support/python/task/vision/core:tensor_image", @@ -54,6 +53,23 @@ ) py_library( + name = "image_searcher", + srcs = [ + "image_searcher.py", + ], + deps = [ + "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", + "//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2", + "//tensorflow_lite_support/python/task/processor/proto:embedding_options_pb2", + "//tensorflow_lite_support/python/task/processor/proto:search_options_py_pb2", + "//tensorflow_lite_support/python/task/processor/proto:search_result_pb2", + "//tensorflow_lite_support/python/task/vision/core:tensor_image", + "//tensorflow_lite_support/python/task/vision/core/pybinds:image_utils", + "//tensorflow_lite_support/python/task/vision/pybinds:_pywrap_image_searcher", + ], +) + +py_library( name = "object_detector", srcs = [ "object_detector.py",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/BUILD index f033fab..409a5de 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/BUILD
@@ -1,4 +1,3 @@ -# Placeholder for internal Python strict test compatibility macro. # Placeholder for internal Python strict library compatibility macro. package( @@ -20,20 +19,3 @@ "//tensorflow_lite_support/python/task/vision/core/pybinds:image_utils", ], ) - -py_test( - name = "tensor_image_test", - srcs = ["tensor_image_test.py"], - data = [ - "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", - ], - deps = [ - ":color_space_type", - ":tensor_image", - # build rule placeholder: numpy dep, - # build rule placeholder: tensorflow dep, - "//tensorflow_lite_support/python/task/vision/core/pybinds:image_utils", - "//tensorflow_lite_support/python/test:test_util", - "@absl_py//absl/testing:parameterized", - ], -)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/BUILD index 5ef08d5..f7eef4b4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/BUILD
@@ -14,7 +14,7 @@ ], module_name = "image_utils", deps = [ - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", "@pybind11", "@pybind11_abseil//pybind11_abseil:status_casters", ],
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc new file mode 100644 index 0000000..124f5cb --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/pybinds/image_utils.cc
@@ -0,0 +1,68 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" + +#include "pybind11/pybind11.h" +#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil + +namespace tflite { +namespace task { +namespace vision { + +namespace { +namespace py = ::pybind11; + +} // namespace + +PYBIND11_MODULE(image_utils, m) { + // python wrapper for ImageData class which shouldn't be directly used by + // the users. + pybind11::google::ImportStatusModule(); + + py::class_<ImageData>(m, "ImageData", py::buffer_protocol()) + .def(py::init([](py::buffer buffer) { + py::buffer_info info = buffer.request(); + + if (info.ndim != 2 && info.ndim != 3) { + throw py::value_error("Incompatible buffer dimension!"); + } + + int height = info.shape[0]; + int width = info.shape[1]; + int channels = info.ndim == 3 ? 
info.shape[2] : 1; + + return ImageData{static_cast<uint8*>(info.ptr), width, height, + channels}; + })) + .def_readonly("width", &ImageData::width) + .def_readonly("height", &ImageData::height) + .def_readonly("channels", &ImageData::channels) + .def_buffer([](ImageData& data) -> py::buffer_info { + return py::buffer_info( + data.pixel_data, sizeof(uint8), + py::format_descriptor<uint8>::format(), 3, + {data.height, data.width, data.channels}, + {sizeof(uint8) * size_t(data.width) * size_t(data.channels), + sizeof(uint8) * size_t(data.channels), sizeof(uint8)}); + }); + + m.def("DecodeImageFromFile", &DecodeImageFromFile); + m.def("EncodeImageToPngFile", &EncodeImageToPngFile); + m.def("ImageDataFree", &ImageDataFree); +} + +} // namespace vision +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_classifier.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_classifier.py index 9504988e..1450418d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_classifier.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_classifier.py
@@ -52,6 +52,7 @@ Args: file_path: Path to the model. + Returns: `ImageClassifier` object that's created from the model file. Raises: @@ -70,6 +71,7 @@ Args: options: Options for the image classifier task. + Returns: `ImageClassifier` object that's created from `options`. Raises: @@ -78,7 +80,7 @@ RuntimeError: If other types of error occurred. """ classifier = _CppImageClassifier.create_from_options( - options.base_options, options.classification_options) + options.base_options, options.classification_options.to_pb2()) return cls(options, classifier) def classify( @@ -104,9 +106,12 @@ """ image_data = image_utils.ImageData(image.buffer) if bounding_box is None: - return self._classifier.classify(image_data) - - return self._classifier.classify(image_data, bounding_box) + classification_result = self._classifier.classify(image_data) + else: + classification_result = self._classifier.classify(image_data, + bounding_box.to_pb2()) + return classifications_pb2.ClassificationResult.create_from_pb2( + classification_result) @property def options(self) -> ImageClassifierOptions:
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_embedder.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_embedder.py index 8e1cc33..299f883 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_embedder.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_embedder.py
@@ -81,8 +81,8 @@ `ImageEmbedderOptions` such as missing the model. RuntimeError: If other types of error occurred. """ - embedder = _CppImageEmbedder.create_from_options(options.base_options, - options.embedding_options) + embedder = _CppImageEmbedder.create_from_options( + options.base_options, options.embedding_options.to_pb2()) return cls(options, embedder) def embed( @@ -107,10 +107,13 @@ RuntimeError: If failed to calculate the embedding vector. """ image_data = image_utils.ImageData(image.buffer) - if bounding_box is None: - return self._embedder.embed(image_data) - return self._embedder.embed(image_data, bounding_box) + if bounding_box is None: + embedding_result = self._embedder.embed(image_data) + else: + embedding_result = self._embedder.embed(image_data, bounding_box.to_pb2()) + + return embedding_pb2.EmbeddingResult.create_from_pb2(embedding_result) def get_embedding_by_index(self, result: embedding_pb2.EmbeddingResult, output_index: int) -> embedding_pb2.Embedding: @@ -130,13 +133,14 @@ """ if output_index < 0 or output_index >= len(result.embeddings): raise ValueError("Output index is out of bound.") - embedding = self._embedder.get_embedding_by_index(result, output_index) - return embedding + embedding = self._embedder.get_embedding_by_index(result.to_pb2(), + output_index) + return embedding_pb2.Embedding.create_from_pb2(embedding) def cosine_similarity(self, u: embedding_pb2.FeatureVector, v: embedding_pb2.FeatureVector) -> float: """Computes cosine similarity [1] between two feature vectors.""" - return self._embedder.cosine_similarity(u, v) + return self._embedder.cosine_similarity(u.to_pb2(), v.to_pb2()) def get_embedding_dimension(self, output_index: int) -> int: """Gets the dimensionality of the embedding output.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_searcher.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_searcher.py new file mode 100644 index 0000000..f295bac --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_searcher.py
@@ -0,0 +1,142 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image searcher task.""" + +import dataclasses +from typing import Optional + +from tensorflow_lite_support.python.task.core.proto import base_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2 +from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import search_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import search_result_pb2 +from tensorflow_lite_support.python.task.vision.core import tensor_image +from tensorflow_lite_support.python.task.vision.core.pybinds import image_utils +from tensorflow_lite_support.python.task.vision.pybinds import _pywrap_image_searcher + +_CppImageSearcher = _pywrap_image_searcher.ImageSearcher +_BaseOptions = base_options_pb2.BaseOptions +_EmbeddingOptions = embedding_options_pb2.EmbeddingOptions +_SearchOptions = search_options_pb2.SearchOptions + + +@dataclasses.dataclass +class ImageSearcherOptions: + """Options for the image search task.""" + base_options: _BaseOptions + embedding_options: _EmbeddingOptions = _EmbeddingOptions() + search_options: _SearchOptions = _SearchOptions() + + +class ImageSearcher(object): + """Class to performs image search. 
+ + It works by performing embedding extraction on images, followed by + nearest-neighbor search in an index of embeddings through ScaNN. + """ + + def __init__(self, options: ImageSearcherOptions, + cpp_searcher: _CppImageSearcher) -> None: + """Initializes the `ImageSearcher` object.""" + # Creates the object of C++ ImageSearcher class. + self._options = options + self._searcher = cpp_searcher + + @classmethod + def create_from_file( + cls, + model_file_path: str, + index_file_path: Optional[str] = None) -> "ImageSearcher": + """Creates the `ImageSearcher` object from a TensorFlow Lite model. + + Args: + model_file_path: Path to the model. + index_file_path: Path to the index. Only required if the index is not + attached to the output tensor metadata as an AssociatedFile with type + SCANN_INDEX_FILE. + + Returns: + `ImageSearcher` object that's created from `options`. + + Raises: + ValueError: If failed to create `ImageSearcher` object from the provided + file such as invalid file. + RuntimeError: If other types of error occurred. + """ + options = ImageSearcherOptions( + base_options=_BaseOptions(file_name=model_file_path), + search_options=_SearchOptions(index_file_name=index_file_path)) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: ImageSearcherOptions) -> "ImageSearcher": + """Creates the `ImageSearcher` object from image searcher options. + + Args: + options: Options for the image searcher task. + + Returns: + `ImageSearcher` object that's created from `options`. + Raises: + ValueError: If failed to create `ImageSearcher` object from + `ImageSearcherOptions` such as missing the model. + RuntimeError: If other types of error occurred. 
+ """ + searcher = _CppImageSearcher.create_from_options( + options.base_options, options.embedding_options.to_pb2(), + options.search_options) + return cls(options, searcher) + + def search( + self, + image: tensor_image.TensorImage, + bounding_box: Optional[bounding_box_pb2.BoundingBox] = None + ) -> search_result_pb2.SearchResult: + """Search for image with similar semantic meaning. + + This method performs actual feature extraction on the provided image input, + followed by nearest-neighbor search in the index. + + Args: + image: Tensor image, used to extract the feature vectors. + bounding_box: Bounding box, optional. If set, performed feature vector + extraction only on the provided region of interest. Note that the region + of interest is not clamped, so this method will fail if the region is + out of bounds of the input image. + + Returns: + Search result. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If failed to perform nearest-neighbor search. + """ + image_data = image_utils.ImageData(image.buffer) + if bounding_box is None: + return self._searcher.search(image_data) + return self._searcher.search(image_data, bounding_box.to_pb2()) + + def get_user_info(self) -> str: + """Gets the user info stored in the index file. + + Returns: + Opaque user info stored in the index file (if any), in raw binary form. + Returns an empty string if the index doesn't contain user info. + """ + return self._searcher.get_user_info() + + @property + def options(self) -> ImageSearcherOptions: + return self._options
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_segmenter.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_segmenter.py index ed31c16..45b24fdd 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_segmenter.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/image_segmenter.py
@@ -14,7 +14,6 @@ """Image segmenter task.""" import dataclasses -from typing import Optional from tensorflow_lite_support.python.task.core.proto import base_options_pb2 from tensorflow_lite_support.python.task.processor.proto import segmentation_options_pb2 @@ -24,6 +23,7 @@ from tensorflow_lite_support.python.task.vision.pybinds import _pywrap_image_segmenter _CppImageSegmenter = _pywrap_image_segmenter.ImageSegmenter +_SegmentationOptions = segmentation_options_pb2.SegmentationOptions _BaseOptions = base_options_pb2.BaseOptions @@ -31,8 +31,7 @@ class ImageSegmenterOptions: """Options for the image segmenter task.""" base_options: _BaseOptions - segmentation_options: Optional[ - segmentation_options_pb2.SegmentationOptions] = None + segmentation_options: _SegmentationOptions = _SegmentationOptions() class ImageSegmenter(object): @@ -77,7 +76,7 @@ RuntimeError: If other types of error occurred. """ segmenter = _CppImageSegmenter.create_from_options( - options.base_options, options.segmentation_options) + options.base_options, options.segmentation_options.to_pb2()) return cls(options, segmenter) def segment( @@ -95,4 +94,6 @@ RuntimeError: If failed to run segmentation. """ image_data = image_utils.ImageData(image.buffer) - return self._segmenter.segment(image_data) + segmentation_result = self._segmenter.segment(image_data) + return segmentations_pb2.SegmentationResult.create_from_pb2( + segmentation_result)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/object_detector.py b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/object_detector.py index 314d31e3..f21afa8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/object_detector.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/object_detector.py
@@ -79,8 +79,8 @@ `ObjectDetectorOptions` such as missing the model. RuntimeError: If other types of error occurred. """ - detector = _CppObjectDetector.create_from_options(options.base_options, - options.detection_options) + detector = _CppObjectDetector.create_from_options( + options.base_options, options.detection_options.to_pb2()) return cls(options, detector) def detect(self, @@ -98,5 +98,5 @@ RuntimeError: If object detection failed to run. """ image_data = image_utils.ImageData(image.buffer) - - return self._detector.detect(image_data) + detection_result = self._detector.detect(image_data) + return detections_pb2.DetectionResult.create_from_pb2(detection_result)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/BUILD index 9c879520..4b291e6 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/BUILD
@@ -1,4 +1,4 @@ -load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") +load("//tensorflow_lite_support/python/task:build_defs.bzl", "pybind_extension_may_pack_coral") package( default_visibility = [ @@ -7,7 +7,7 @@ licenses = ["notice"], # Apache 2.0 ) -pybind_extension( +pybind_extension_may_pack_coral( name = "_pywrap_image_embedder", srcs = [ "_pywrap_image_embedder.cc", @@ -18,14 +18,14 @@ "//tensorflow_lite_support/cc/task/processor/proto:embedding_cc_proto", "//tensorflow_lite_support/cc/task/processor/proto:embedding_options_cc_proto", "//tensorflow_lite_support/cc/task/vision:image_embedder", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", "//tensorflow_lite_support/python/task/core/pybinds:task_utils", "@pybind11", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ], ) -pybind_extension( +pybind_extension_may_pack_coral( name = "_pywrap_image_classifier", srcs = [ "_pywrap_image_classifier.cc", @@ -36,14 +36,14 @@ "//tensorflow_lite_support/cc/task/processor/proto:classification_options_cc_proto", "//tensorflow_lite_support/cc/task/processor/proto:classifications_cc_proto", "//tensorflow_lite_support/cc/task/vision:image_classifier", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", "//tensorflow_lite_support/python/task/core/pybinds:task_utils", "@pybind11", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ], ) -pybind_extension( +pybind_extension_may_pack_coral( name = "_pywrap_image_segmenter", srcs = [ "_pywrap_image_segmenter.cc", @@ -52,14 +52,31 @@ deps = [ "//tensorflow_lite_support/cc/task/processor/proto:segmentation_options_cc_proto", "//tensorflow_lite_support/cc/task/vision:image_segmenter", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + 
"//tensorflow_lite_support/cc/task/vision/utils:image_utils", "//tensorflow_lite_support/python/task/core/pybinds:task_utils", "@pybind11", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ], ) -pybind_extension( +pybind_extension_may_pack_coral( + name = "_pywrap_image_searcher", + srcs = [ + "_pywrap_image_searcher.cc", + ], + module_name = "_pywrap_image_searcher", + deps = [ + "//tensorflow_lite_support/cc/task/processor/proto:bounding_box_cc_proto", + "//tensorflow_lite_support/cc/task/vision:image_searcher", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", + "//tensorflow_lite_support/python/task/core/pybinds:task_utils", + "//tensorflow_lite_support/python/task/processor/proto:search_options_cc_proto", + "@pybind11", + "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", + ], +) + +pybind_extension_may_pack_coral( name = "_pywrap_object_detector", srcs = [ "_pywrap_object_detector.cc", @@ -69,7 +86,7 @@ "//tensorflow_lite_support/cc/task/processor/proto:detection_options_cc_proto", "//tensorflow_lite_support/cc/task/processor/proto:detections_cc_proto", "//tensorflow_lite_support/cc/task/vision:object_detector", - "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_utils", "//tensorflow_lite_support/python/task/core/pybinds:task_utils", "@pybind11", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc new file mode 100644 index 0000000..b4f23baa6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_classifier.cc
@@ -0,0 +1,108 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "pybind11/pybind11.h" +#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf +#include "tensorflow_lite_support/cc/task/processor/proto/bounding_box.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h" +#include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h" +#include "tensorflow_lite_support/cc/task/vision/image_classifier.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h" +#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h" + +namespace tflite { +namespace task { +namespace vision { + +namespace { +namespace py = ::pybind11; +using PythonBaseOptions = ::tflite::python::task::core::BaseOptions; +using CppBaseOptions = ::tflite::task::core::BaseOptions; +} // namespace + +PYBIND11_MODULE(_pywrap_image_classifier, m) { + // python wrapper for C++ ImageClassifier class which shouldn't be directly + // used by the users. 
+ pybind11_protobuf::ImportNativeProtoCasters(); + + py::class_<ImageClassifier>(m, "ImageClassifier") + .def_static( + "create_from_options", + [](const PythonBaseOptions& base_options, + const processor::ClassificationOptions& classification_options) { + ImageClassifierOptions options; + auto cpp_base_options = + core::convert_to_cpp_base_options(base_options); + options.set_allocated_base_options(cpp_base_options.release()); + + if (classification_options.has_display_names_locale()) { + options.set_display_names_locale( + classification_options.display_names_locale()); + } + if (classification_options.has_max_results()) { + options.set_max_results(classification_options.max_results()); + } + if (classification_options.has_score_threshold()) { + options.set_score_threshold( + classification_options.score_threshold()); + } + options.mutable_class_name_whitelist()->CopyFrom( + classification_options.class_name_allowlist()); + options.mutable_class_name_blacklist()->CopyFrom( + classification_options.class_name_denylist()); + + auto classifier = ImageClassifier::CreateFromOptions(options); + return core::get_value(classifier); + }) + .def("classify", + [](ImageClassifier& self, + const ImageData& image_data) -> processor::ClassificationResult { + auto frame_buffer = CreateFrameBufferFromImageData(image_data); + auto vision_classification_result = + self.Classify(*core::get_value(frame_buffer)); + // Convert from vision::ClassificationResult to + // processor::ClassificationResult as required by the Python layer. 
+ processor::ClassificationResult classification_result; + classification_result.ParseFromString( + core::get_value(vision_classification_result) + .SerializeAsString()); + return classification_result; + }) + .def("classify", + [](ImageClassifier& self, const ImageData& image_data, + const processor::BoundingBox& bounding_box) + -> processor::ClassificationResult { + // Convert from processor::BoundingBox to vision::BoundingBox as + // the latter is used in the C++ layer. + BoundingBox vision_bounding_box; + vision_bounding_box.ParseFromString( + bounding_box.SerializeAsString()); + + auto frame_buffer = CreateFrameBufferFromImageData(image_data); + auto vision_classification_result = self.Classify( + *core::get_value(frame_buffer), vision_bounding_box); + // Convert from vision::ClassificationResult to + // processor::ClassificationResult as required by the Python layer. + processor::ClassificationResult classification_result; + classification_result.ParseFromString( + core::get_value(vision_classification_result) + .SerializeAsString()); + return classification_result; + }); +} + +} // namespace vision +} // namespace task +} // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_embedder.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_embedder.cc new file mode 100644 index 0000000..4f3bd58 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_embedder.cc
@@ -0,0 +1,145 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdexcept>
+
+#include "pybind11/pybind11.h"
+#include "pybind11_protobuf/native_proto_caster.h"  // from @pybind11_protobuf
+#include "tensorflow_lite_support/cc/task/processor/proto/bounding_box.pb.h"
+#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h"
+#include "tensorflow_lite_support/cc/task/vision/image_embedder.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
+#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+namespace {
+namespace py = ::pybind11;
+using PythonBaseOptions = ::tflite::python::task::core::BaseOptions;
+using CppBaseOptions = ::tflite::task::core::BaseOptions;
+}  // namespace
+
+PYBIND11_MODULE(_pywrap_image_embedder, m) {
+  // python wrapper for C++ ImageEmbedder class which shouldn't be directly used
+  // by the users.
+  pybind11_protobuf::ImportNativeProtoCasters();
+
+  py::class_<ImageEmbedder>(m, "ImageEmbedder")
+      .def_static(
+          "create_from_options",
+          [](const PythonBaseOptions& base_options,
+             const processor::EmbeddingOptions& embedding_options) {
+            ImageEmbedderOptions options;
+            // Only file_content, file_name, num_threads and use_coral are
+            // read from BaseOptions; any other field is not forwarded here.
+            if (base_options.has_file_content()) {
+              options.mutable_model_file_with_metadata()->set_file_content(
+                  base_options.file_content());
+            }
+            if (base_options.has_file_name()) {
+              options.mutable_model_file_with_metadata()->set_file_name(
+                  base_options.file_name());
+            }
+            options.set_num_threads(base_options.num_threads());
+            if (base_options.use_coral()) {
+              options.mutable_compute_settings()
+                  ->mutable_tflite_settings()
+                  ->set_delegate(tflite::proto::Delegate::EDGETPU_CORAL);
+            }
+            if (embedding_options.has_l2_normalize()) {
+              options.set_l2_normalize(embedding_options.l2_normalize());
+            }
+            if (embedding_options.has_quantize()) {
+              options.set_quantize(embedding_options.quantize());
+            }
+            auto embedder = ImageEmbedder::CreateFromOptions(options);
+            return core::get_value(embedder);
+          })
+      .def("embed",
+           [](ImageEmbedder& self,
+              const ImageData& image_data) -> processor::EmbeddingResult {
+             auto frame_buffer = CreateFrameBufferFromImageData(image_data);
+             auto vision_embedding_result =
+                 self.Embed(*core::get_value(frame_buffer));
+             // Convert from vision::EmbeddingResult to
+             // processor::EmbeddingResult as required by the Python layer.
+             processor::EmbeddingResult embedding_result;
+             embedding_result.ParseFromString(
+                 core::get_value(vision_embedding_result).SerializeAsString());
+             return embedding_result;
+           })
+      .def("embed",
+           [](ImageEmbedder& self, const ImageData& image_data,
+              const processor::BoundingBox& bounding_box)
+               -> processor::EmbeddingResult {
+             // Convert from processor::BoundingBox to vision::BoundingBox as
+             // the latter is used in the C++ layer.
+             BoundingBox vision_bounding_box;
+             vision_bounding_box.ParseFromString(
+                 bounding_box.SerializeAsString());
+
+             auto frame_buffer = CreateFrameBufferFromImageData(image_data);
+             auto vision_embedding_result = self.Embed(
+                 *core::get_value(frame_buffer), vision_bounding_box);
+             // Convert from vision::EmbeddingResult to
+             // processor::EmbeddingResult as required by the Python layer.
+             processor::EmbeddingResult embedding_result;
+             embedding_result.ParseFromString(
+                 core::get_value(vision_embedding_result).SerializeAsString());
+             return embedding_result;
+           })
+      .def("get_embedding_by_index",
+           [](ImageEmbedder& self,
+              const processor::EmbeddingResult& embedding_result,
+              const int index) -> processor::Embedding {
+             // Convert from processor::EmbeddingResult to
+             // vision::EmbeddingResult as the latter is used in the C++ API.
+             EmbeddingResult vision_embedding_result;
+             vision_embedding_result.ParseFromString(
+                 embedding_result.SerializeAsString());
+
+             Embedding vision_embedding{
+                 self.GetEmbeddingByIndex(vision_embedding_result, index)};
+             // Convert from vision::Embedding to processor::Embedding
+             // as required by the Python layer.
+             processor::Embedding embedding;
+             embedding.ParseFromString(vision_embedding.SerializeAsString());
+             return embedding;
+           })
+      .def("get_number_of_output_layers",
+           &ImageEmbedder::GetNumberOfOutputLayers)
+      .def("get_embedding_dimension", &ImageEmbedder::GetEmbeddingDimension)
+      .def_static(
+          "cosine_similarity",
+          [](const processor::FeatureVector& u,
+             const processor::FeatureVector& v) -> double {
+            // Convert from processor::FeatureVector to
+            // vision::FeatureVector as the latter is used in the C++
+            // layer.
+            FeatureVector vision_feature_vector_u;
+            vision_feature_vector_u.ParseFromString(u.SerializeAsString());
+            FeatureVector vision_feature_vector_v;
+            vision_feature_vector_v.ParseFromString(v.SerializeAsString());
+            auto similarity = ImageEmbedder::CosineSimilarity(
+                vision_feature_vector_u, vision_feature_vector_v);
+            return core::get_value(similarity);
+          });
+}
+
+}  // namespace vision
+}  // namespace task
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_searcher.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_searcher.cc new file mode 100644 index 0000000..bce7e926 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_searcher.cc
@@ -0,0 +1,111 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "pybind11/pybind11.h"
+#include "pybind11_protobuf/native_proto_caster.h"  // from @pybind11_protobuf
+#include "tensorflow_lite_support/cc/task/processor/proto/bounding_box.pb.h"
+#include "tensorflow_lite_support/cc/task/vision/image_searcher.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
+#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
+#include "tensorflow_lite_support/python/task/processor/proto/search_options.pb.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+namespace {
+namespace py = ::pybind11;
+using PythonBaseOptions = ::tflite::python::task::core::BaseOptions;
+using PythonSearchOptions = ::tflite::python::task::processor::SearchOptions;
+using CppBaseOptions = ::tflite::task::core::BaseOptions;
+using CppEmbeddingOptions = ::tflite::task::processor::EmbeddingOptions;
+using CppSearchOptions = ::tflite::task::processor::SearchOptions;
+}  // namespace
+
+PYBIND11_MODULE(_pywrap_image_searcher, m) {
+  // python wrapper for C++ ImageSearcher class which shouldn't be directly used
+  // by the users.
+  pybind11_protobuf::ImportNativeProtoCasters();
+
+  py::class_<ImageSearcher>(m, "ImageSearcher")
+      .def_static(
+          "create_from_options",
+          [](const PythonBaseOptions& base_options,
+             const processor::EmbeddingOptions& embedding_options,
+             const PythonSearchOptions& search_options) {
+            ImageSearcherOptions options;
+            auto cpp_base_options =
+                core::convert_to_cpp_base_options(base_options);
+            options.set_allocated_base_options(cpp_base_options.release());
+
+            std::unique_ptr<CppEmbeddingOptions> cpp_embedding_options =
+                std::make_unique<CppEmbeddingOptions>();
+            cpp_embedding_options->CopyFrom(embedding_options);
+            options.set_allocated_embedding_options(
+                cpp_embedding_options.release());
+
+            std::unique_ptr<CppSearchOptions> cpp_search_options =
+                std::make_unique<CppSearchOptions>();
+            if (search_options.has_index_file_content()) {
+              cpp_search_options->mutable_index_file()->set_file_content(
+                  search_options.index_file_content());
+            }
+            if (search_options.has_index_file_name()) {
+              cpp_search_options->mutable_index_file()->set_file_name(
+                  search_options.index_file_name());
+            }
+            if (search_options.has_max_results()) {
+              cpp_search_options->set_max_results(search_options.max_results());
+            }
+
+            options.set_allocated_search_options(cpp_search_options.release());
+            auto searcher = ImageSearcher::CreateFromOptions(options);
+            return core::get_value(searcher);
+          })
+      .def("search",
+           [](ImageSearcher& self,
+              const ImageData& image_data) -> processor::SearchResult {
+             auto frame_buffer = CreateFrameBufferFromImageData(image_data);
+             auto search_result = self.Search(*core::get_value(frame_buffer));
+             return core::get_value(search_result);
+           })
+      .def("search",
+           [](ImageSearcher& self, const ImageData& image_data,
+              const processor::BoundingBox& bounding_box)
+               -> processor::SearchResult {
+             // Convert from processor::BoundingBox to vision::BoundingBox as
+             // the latter is used in the C++ layer.
+             BoundingBox vision_bounding_box;
+             vision_bounding_box.ParseFromString(
+                 bounding_box.SerializeAsString());
+
+             auto frame_buffer = CreateFrameBufferFromImageData(image_data);
+             auto search_result = self.Search(*core::get_value(frame_buffer),
+                                              vision_bounding_box);
+             return core::get_value(search_result);
+           })
+      .def("get_user_info", [](ImageSearcher& self) -> py::str {
+        // GetUserInfo() returns a StatusOr wrapping a string view: unwrap it
+        // with core::get_value, like every other StatusOr in this file,
+        // instead of dereferencing it unchecked. The view is not guaranteed
+        // to be NUL-terminated, so its size is passed explicitly.
+        auto user_info_or = self.GetUserInfo();
+        auto user_info = core::get_value(user_info_or);
+        return py::str(user_info.data(), user_info.size());
+      });
+}
+
+}  // namespace vision
+}  // namespace task
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc new file mode 100644 index 0000000..e71048e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_image_segmenter.cc
@@ -0,0 +1,82 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "pybind11/pybind11.h"
+#include "pybind11_protobuf/native_proto_caster.h"  // from @pybind11_protobuf
+#include "tensorflow_lite_support/cc/task/processor/proto/segmentation_options.pb.h"
+#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
+#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+namespace {
+namespace py = ::pybind11;
+using PythonBaseOptions = ::tflite::python::task::core::BaseOptions;
+using CppBaseOptions = ::tflite::task::core::BaseOptions;
+
+// Builds the C++ ImageSegmenterOptions from the Python-facing option protos.
+// Optional segmentation fields are copied only when explicitly set.
+ImageSegmenterOptions BuildSegmenterOptions(
+    const PythonBaseOptions& base_options,
+    const processor::SegmentationOptions& segmentation_options) {
+  ImageSegmenterOptions segmenter_options;
+  segmenter_options.set_allocated_base_options(
+      core::convert_to_cpp_base_options(base_options).release());
+  if (segmentation_options.has_display_names_locale()) {
+    segmenter_options.set_display_names_locale(
+        segmentation_options.display_names_locale());
+  }
+  if (segmentation_options.has_output_type()) {
+    // The processor-level enum is mapped onto the vision-level one by its
+    // numeric value.
+    segmenter_options.set_output_type(
+        static_cast<ImageSegmenterOptions::OutputType>(
+            segmentation_options.output_type()));
+  }
+  return segmenter_options;
+}
+
+}  // namespace
+
+PYBIND11_MODULE(_pywrap_image_segmenter, m) {
+  // Internal binding for the C++ ImageSegmenter class; it is not meant to be
+  // used directly by end users.
+  pybind11_protobuf::ImportNativeProtoCasters();
+
+  py::class_<ImageSegmenter>(m, "ImageSegmenter")
+      .def_static(
+          "create_from_options",
+          [](const PythonBaseOptions& base_options,
+             const processor::SegmentationOptions& segmentation_options) {
+            const ImageSegmenterOptions options =
+                BuildSegmenterOptions(base_options, segmentation_options);
+            auto segmenter_or = ImageSegmenter::CreateFromOptions(options);
+            return core::get_value(segmenter_or);
+          })
+      .def("segment",
+           [](ImageSegmenter& self,
+              const ImageData& image_data) -> SegmentationResult {
+             auto buffer_or = CreateFrameBufferFromImageData(image_data);
+             auto result_or = self.Segment(*core::get_value(buffer_or));
+             return core::get_value(result_or);
+           });
+}
+
+}  // namespace vision
+}  // namespace task
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc new file mode 100644 index 0000000..3749efc --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/pybinds/_pywrap_object_detector.cc
@@ -0,0 +1,94 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "pybind11/pybind11.h"
+#include "pybind11_protobuf/native_proto_caster.h"  // from @pybind11_protobuf
+#include "tensorflow_lite_support/cc/task/processor/proto/detection_options.pb.h"
+#include "tensorflow_lite_support/cc/task/processor/proto/detections.pb.h"
+#include "tensorflow_lite_support/cc/task/vision/object_detector.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/image_utils.h"
+#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+namespace {
+namespace py = ::pybind11;
+using PythonBaseOptions = ::tflite::python::task::core::BaseOptions;
+using CppBaseOptions = ::tflite::task::core::BaseOptions;
+
+// Builds the C++ ObjectDetectorOptions from the Python-facing option protos.
+// Scalar fields are copied only when explicitly set; the allow/deny lists are
+// always copied, even when empty.
+ObjectDetectorOptions BuildDetectorOptions(
+    const PythonBaseOptions& base_options,
+    const processor::DetectionOptions& detection_options) {
+  ObjectDetectorOptions detector_options;
+  detector_options.set_allocated_base_options(
+      core::convert_to_cpp_base_options(base_options).release());
+  if (detection_options.has_display_names_locale()) {
+    detector_options.set_display_names_locale(
+        detection_options.display_names_locale());
+  }
+  if (detection_options.has_max_results()) {
+    detector_options.set_max_results(detection_options.max_results());
+  }
+  if (detection_options.has_score_threshold()) {
+    detector_options.set_score_threshold(detection_options.score_threshold());
+  }
+  // The C++ options name these lists whitelist/blacklist.
+  detector_options.mutable_class_name_whitelist()->CopyFrom(
+      detection_options.class_name_allowlist());
+  detector_options.mutable_class_name_blacklist()->CopyFrom(
+      detection_options.class_name_denylist());
+  return detector_options;
+}
+
+}  // namespace
+
+PYBIND11_MODULE(_pywrap_object_detector, m) {
+  // Internal binding for the C++ ObjectDetector class; it is not meant to be
+  // used directly by end users.
+  pybind11_protobuf::ImportNativeProtoCasters();
+
+  py::class_<ObjectDetector>(m, "ObjectDetector")
+      .def_static(
+          "create_from_options",
+          [](const PythonBaseOptions& base_options,
+             const processor::DetectionOptions& detection_options) {
+            const ObjectDetectorOptions options =
+                BuildDetectorOptions(base_options, detection_options);
+            auto detector_or = ObjectDetector::CreateFromOptions(options);
+            return core::get_value(detector_or);
+          })
+      .def("detect",
+           [](ObjectDetector& self,
+              const ImageData& image_data) -> processor::DetectionResult {
+             auto buffer_or = CreateFrameBufferFromImageData(image_data);
+             auto detections_or = self.Detect(*core::get_value(buffer_or));
+             // Re-encode the vision::DetectionResult as the
+             // processor::DetectionResult exposed to Python; this relies on
+             // the two messages being wire-compatible.
+             processor::DetectionResult converted;
+             converted.ParseFromString(
+                 core::get_value(detections_or).SerializeAsString());
+             return converted;
+           });
+}
+
+}  // namespace vision
+}  // namespace task
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/test/BUILD index cf5d45a..52b142e0 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/BUILD
@@ -6,13 +6,6 @@ ) py_library( - name = "base_test", - testonly = 1, - srcs = ["base_test.py"], - srcs_version = "PY3", -) - -py_library( name = "test_util", testonly = 1, srcs = ["test_util.py"],
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/base_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/base_test.py deleted file mode 100644 index edb42d0..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/base_test.py +++ /dev/null
@@ -1,50 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Base TestCase for the unit tests.""" - -import unittest - -__unittest = True # Allows shorter stack trace for .assertDeepAlmostEqual pylint: disable=invalid-name - - -class BaseTestCase(unittest.TestCase): - """Base test case.""" - - def assertDeepAlmostEqual(self, expected, actual, **kwargs): - """Compares lists, dicts and tuples recursively. - - Checks numeric values using test_case's - :py:meth:`unittest.TestCase.assertAlmostEqual` and checks all other values - with :py:meth:`unittest.TestCase.assertEqual`. Accepts additional keyword - arguments and pass those intact to assertAlmostEqual() (that's how you - specify comparison precision). - - Args: - expected: Expected object. - actual: Actual object. - **kwargs: Other parameters to be passed. - """ - if isinstance(expected, (int, float, complex)): - self.assertAlmostEqual(expected, actual, **kwargs) - elif isinstance(expected, (list, tuple)): - self.assertEqual(len(expected), len(actual)) - for index in range(len(expected)): - v1, v2 = expected[index], actual[index] - self.assertDeepAlmostEqual(v1, v2, **kwargs) - elif isinstance(expected, dict): - self.assertEqual(set(expected), set(actual)) - for key in expected: - self.assertDeepAlmostEqual(expected[key], actual[key], **kwargs) - else: - self.assertEqual(expected, actual)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/BUILD index 1114e87..6556a00 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/BUILD
@@ -13,12 +13,13 @@ "//tensorflow_lite_support/cc/test/testdata/task/audio:test_models", ], deps = [ + # build rule placeholder: numpy dep, + # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/audio:audio_embedder", "//tensorflow_lite_support/python/task/audio/core:audio_record", "//tensorflow_lite_support/python/task/audio/core:tensor_audio", "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", "//tensorflow_lite_support/python/task/processor/proto:embedding_options_pb2", - "//tensorflow_lite_support/python/test:base_test", "//tensorflow_lite_support/python/test:test_util", "@absl_py//absl/testing:parameterized", ], @@ -32,6 +33,7 @@ "//tensorflow_lite_support/cc/test/testdata/task/audio:test_models", ], deps = [ + # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/audio:audio_classifier", "//tensorflow_lite_support/python/task/audio/core:audio_record", "//tensorflow_lite_support/python/task/audio/core:tensor_audio", @@ -39,9 +41,7 @@ "//tensorflow_lite_support/python/task/processor/proto:class_pb2", "//tensorflow_lite_support/python/task/processor/proto:classification_options_pb2", "//tensorflow_lite_support/python/task/processor/proto:classifications_pb2", - "//tensorflow_lite_support/python/test:base_test", "//tensorflow_lite_support/python/test:test_util", "@absl_py//absl/testing:parameterized", - "@com_google_protobuf//:protobuf_python", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py index 5173fa8..1b03f25 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/audio_classifier_test.py
@@ -14,25 +14,18 @@ """Tests for audio_classifier.""" import enum -import json from absl.testing import parameterized +import tensorflow as tf -from google.protobuf import json_format import unittest from tensorflow_lite_support.python.task.audio import audio_classifier from tensorflow_lite_support.python.task.audio.core import audio_record from tensorflow_lite_support.python.task.audio.core import tensor_audio from tensorflow_lite_support.python.task.core.proto import base_options_pb2 -from tensorflow_lite_support.python.task.processor.proto import class_pb2 from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2 -from tensorflow_lite_support.python.task.processor.proto import classifications_pb2 -from tensorflow_lite_support.python.test import base_test from tensorflow_lite_support.python.test import test_util -# TODO(b/220067158): Change to import tensorflow and leverage tf.test once -# fixed the dependency issue. - _mock = unittest.mock _BaseOptions = base_options_pb2.BaseOptions _AudioClassifier = audio_classifier.AudioClassifier @@ -40,58 +33,84 @@ _FIXED_INPUT_SIZE_MODEL_FILE = 'yamnet_audio_classifier_with_metadata.tflite' _SPEECH_AUDIO_FILE = 'speech.wav' -_FIXED_INPUT_SIZE_MODEL_CLASSIFICATIONS = { - 'scores': [{ - 'index': 0, - 'score': 0.91796875, - 'class_name': 'Speech' - }, { - 'index': 500, - 'score': 0.05859375, - 'class_name': 'Inside, small room' - }, { - 'index': 494, - 'score': 0.01367188, - 'class_name': 'Silence' - }] +_FIXED_INPUT_SIZE_MODEL_CLASSIFICATIONS = """ +classifications { + classes { + index: 0 + score: 0.917969 + display_name: "" + class_name: "Speech" + } + classes { + index: 500 + score: 0.058594 + display_name: "" + class_name: "Inside, small room" + } + classes { + index: 494 + score: 0.011719 + display_name: "" + class_name: "Silence" + } + head_index: 0 + head_name: "scores" } +""" _MULTIHEAD_MODEL_FILE = 'two_heads.tflite' _TWO_HEADS_AUDIO_FILE = 'two_heads.wav' 
-_MULTIHEAD_MODEL_CLASSIFICATIONS = { - 'yamnet_classification': [{ - 'index': 508, - 'score': 0.5486158, - 'class_name': 'Environmental noise' - }, { - 'index': 507, - 'score': 0.38086897, - 'class_name': 'Noise' - }, { - 'index': 106, - 'score': 0.25613675, - 'class_name': 'Bird' - }], - 'bird_classification': [{ - 'index': 4, - 'score': 0.93399656, - 'class_name': 'Chestnut-crowned Antpitta' - }, { - 'index': 1, - 'score': 0.065934494, - 'class_name': 'White-breasted Wood-Wren' - }, { - 'index': 0, - 'score': 6.1469495e-05, - 'class_name': 'Red Crossbill' - }] +_MULTIHEAD_MODEL_CLASSIFICATIONS = """ +classifications { + classes { + index: 508 + score: 0.548616 + display_name: "" + class_name: "Environmental noise" + } + classes { + index: 507 + score: 0.380869 + display_name: "" + class_name: "Noise" + } + classes { + index: 106 + score: 0.256137 + display_name: "" + class_name: "Bird" + } + head_index: 0 + head_name: "yamnet_classification" } +classifications { + classes { + index: 4 + score: 0.933997 + display_name: "" + class_name: "Chestnut-crowned Antpitta" + } + classes { + index: 1 + score: 0.065934 + display_name: "" + class_name: "White-breasted Wood-Wren" + } + classes { + index: 0 + score: 6.1469495e-05 + display_name: "" + class_name: "Red Crossbill" + } + head_index: 1 + head_name: "bird_classification" +} +""" _ALLOW_LIST = ['Speech', 'Inside, small room'] _DENY_LIST = ['Speech'] _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 -_ACCEPTABLE_ERROR_RANGE = 0.005 class ModelFileType(enum.Enum): @@ -108,22 +127,7 @@ return classifier -def _build_test_data(classifications): - expected_result = classifications_pb2.ClassificationResult() - - for index, (head_name, categories) in enumerate(classifications.items()): - classifications = classifications_pb2.Classifications( - head_index=index, head_name=head_name) - classifications.classes.extend( - [class_pb2.Category(**args) for args in categories]) - expected_result.classifications.append(classifications) - - 
expected_result_dict = json.loads(json_format.MessageToJson(expected_result)) - - return expected_result_dict - - -class AudioClassifierTest(parameterized.TestCase, base_test.BaseTestCase): +class AudioClassifierTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): super().setUp() @@ -195,7 +199,7 @@ (_MULTIHEAD_MODEL_FILE, ModelFileType.FILE_CONTENT, _TWO_HEADS_AUDIO_FILE, 3, _MULTIHEAD_MODEL_CLASSIFICATIONS)) def test_classify_model(self, model_name, model_file_type, audio_file_name, - max_results, expected_classifications): + max_results, expected_result_text_proto): # Creates classifier. model_path = test_util.get_test_data_path(model_name) if model_file_type is ModelFileType.FILE_NAME: @@ -218,14 +222,9 @@ # Classifies the input. audio_result = classifier.classify(tensor) - audio_result_dict = json.loads(json_format.MessageToJson(audio_result)) - - # Builds test data. - expected_result_dict = _build_test_data(expected_classifications) # Comparing results. - self.assertDeepAlmostEqual( - audio_result_dict, expected_result_dict, delta=_ACCEPTABLE_ERROR_RANGE) + self.assertProtoEquals(expected_result_text_proto, audio_result.to_pb2()) def test_max_results_option(self): # Creates classifier. @@ -240,9 +239,7 @@ # Classifies the input. audio_result = classifier.classify(tensor) - audio_result_dict = json.loads(json_format.MessageToJson(audio_result)) - - categories = audio_result_dict['classifications'][0]['classes'] + categories = audio_result.classifications[0].categories self.assertLessEqual( len(categories), _MAX_RESULTS, 'Too many results returned.') @@ -260,14 +257,11 @@ # Classifies the input. 
audio_result = classifier.classify(tensor) - audio_result_dict = json.loads(json_format.MessageToJson(audio_result)) - - categories = audio_result_dict['classifications'][0]['classes'] + categories = audio_result.classifications[0].categories for category in categories: - score = category['score'] self.assertGreaterEqual( - score, _SCORE_THRESHOLD, + category.score, _SCORE_THRESHOLD, 'Classification with score lower than threshold found. {0}'.format( category)) @@ -276,7 +270,7 @@ base_options = _BaseOptions(file_name=self.model_path) classifier = _create_classifier_from_options( - base_options, class_name_allowlist=_ALLOW_LIST) + base_options, category_name_allowlist=_ALLOW_LIST) # Load the input audio file. tensor = tensor_audio.TensorAudio.create_from_wav_file( @@ -284,12 +278,10 @@ # Classifies the input. audio_result = classifier.classify(tensor) - audio_result_dict = json.loads(json_format.MessageToJson(audio_result)) - - categories = audio_result_dict['classifications'][0]['classes'] + categories = audio_result.classifications[0].categories for category in categories: - label = category['className'] + label = category.category_name self.assertIn( label, _ALLOW_LIST, 'Label "{0}" found but not in label allow list'.format(label)) @@ -299,7 +291,7 @@ base_options = _BaseOptions(file_name=self.model_path) classifier = _create_classifier_from_options( - base_options, score_threshold=0.01, class_name_denylist=_DENY_LIST) + base_options, score_threshold=0.01, category_name_denylist=_DENY_LIST) # Load the input audio file. tensor = tensor_audio.TensorAudio.create_from_wav_file( @@ -307,12 +299,10 @@ # Classifies the input. 
audio_result = classifier.classify(tensor) - audio_result_dict = json.loads(json_format.MessageToJson(audio_result)) - - categories = audio_result_dict['classifications'][0]['classes'] + categories = audio_result.classifications[0].categories for category in categories: - label = category['className'] + label = category.category_name self.assertNotIn(label, _DENY_LIST, 'Label "{0}" found but in deny list.'.format(label)) @@ -324,7 +314,7 @@ r'exclusive options.'): base_options = _BaseOptions(file_name=self.model_path) classification_options = classification_options_pb2.ClassificationOptions( - class_name_allowlist=['foo'], class_name_denylist=['bar']) + category_name_allowlist=['foo'], category_name_denylist=['bar']) options = _AudioClassifierOptions( base_options=base_options, classification_options=classification_options) @@ -332,4 +322,4 @@ if __name__ == '__main__': - unittest.main() + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/audio_embedder_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/audio_embedder_test.py index e3f08d10..6261199 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/audio_embedder_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/audio_embedder_test.py
@@ -16,16 +16,15 @@ import enum from absl.testing import parameterized -# TODO(b/220067158): Change to import tensorflow and leverage tf.test once -# fixed the dependency issue. -import unittest +import numpy as np +import tensorflow as tf +import unittest from tensorflow_lite_support.python.task.audio import audio_embedder from tensorflow_lite_support.python.task.audio.core import audio_record from tensorflow_lite_support.python.task.audio.core import tensor_audio from tensorflow_lite_support.python.task.core.proto import base_options_pb2 from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2 -from tensorflow_lite_support.python.test import base_test from tensorflow_lite_support.python.test import test_util _mock = unittest.mock @@ -41,7 +40,7 @@ FILE_NAME = 2 -class AudioEmbedderTest(parameterized.TestCase, base_test.BaseTestCase): +class AudioEmbedderTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): super().setUp() @@ -100,11 +99,11 @@ self.assertEqual(record.buffer_size, 15600) @parameterized.parameters((_YAMNET_EMBEDDING_MODEL_FILE, False, False, - ModelFileType.FILE_NAME, 1024, 0.091439), + ModelFileType.FILE_NAME, 1024, 0.091439, 0), (_YAMNET_EMBEDDING_MODEL_FILE, True, True, - ModelFileType.FILE_CONTENT, 1024, 0.092382)) + ModelFileType.FILE_CONTENT, 1024, 0.092382, 0)) def test_embed(self, model_name, l2_normalize, quantize, model_file_type, - embedding_length, expected_similarity): + embedding_length, expected_similarity, expected_first_value): # Create embedder. model_path = test_util.get_test_data_path(model_name) if model_file_type is ModelFileType.FILE_NAME: @@ -137,26 +136,25 @@ result1 = embedder.embed(tensor1) # Check embedding sizes. 
- def _check_embedding_size(result): - self.assertLen(result.embeddings, 1) - feature_vector = result.embeddings[0].feature_vector - if quantize: - self.assertLen(feature_vector.value_string, embedding_length) - else: - self.assertLen(feature_vector.value_float, embedding_length) - - _check_embedding_size(result0) - _check_embedding_size(result1) - + self.assertLen(result0.embeddings, 1) result0_feature_vector = result0.embeddings[0].feature_vector + self.assertLen(result1.embeddings, 1) result1_feature_vector = result1.embeddings[0].feature_vector + self.assertLen(result0_feature_vector.value, embedding_length) + self.assertLen(result1_feature_vector.value, embedding_length) + if quantize: - self.assertLen(result0_feature_vector.value_string, 1024) - self.assertLen(result1_feature_vector.value_string, 1024) + self.assertEqual(result0_feature_vector.value.dtype, np.uint8) else: - self.assertLen(result0_feature_vector.value_float, 1024) - self.assertLen(result1_feature_vector.value_float, 1024) + self.assertEqual(result1_feature_vector.value.dtype, float) + + self.assertLen(result0_feature_vector.value, 1024) + self.assertLen(result1_feature_vector.value, 1024) + + # Check embedding value. + self.assertAlmostEqual(result0_feature_vector.value[0], + expected_first_value) # Checks cosine similarity. similarity = embedder.cosine_similarity(result0_feature_vector, @@ -176,4 +174,4 @@ if __name__ == "__main__": - unittest.main() + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/BUILD index db248ed..ad10a79 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/BUILD
@@ -10,6 +10,7 @@ srcs = ["audio_record_test.py"], deps = [ # build rule placeholder: numpy dep, + # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/audio/core:audio_record", ], ) @@ -22,6 +23,7 @@ ], deps = [ # build rule placeholder: numpy dep, + # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/audio/core:audio_record", "//tensorflow_lite_support/python/task/audio/core:tensor_audio", "//tensorflow_lite_support/python/task/audio/core/pybinds:_pywrap_audio_buffer",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/audio_record_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/audio_record_test.py index 00023557..3f40779 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/audio_record_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/audio_record_test.py
@@ -13,27 +13,26 @@ # limitations under the License. """Tests for audio_record.""" -from unittest import mock - import numpy as np -from numpy import testing - +import tensorflow as tf import unittest from tensorflow_lite_support.python.task.audio.core import audio_record +_mock = unittest.mock + _CHANNELS = 2 _SAMPLING_RATE = 16000 _BUFFER_SIZE = 15600 -class AudioRecordTest(unittest.TestCase): +class AudioRecordTest(tf.test.TestCase): def setUp(self): super().setUp() # Mock sounddevice.InputStream - with mock.patch("sounddevice.InputStream") as mock_input_stream_new_method: - self.mock_input_stream = mock.MagicMock() + with _mock.patch("sounddevice.InputStream") as mock_input_stream_new_method: + self.mock_input_stream = _mock.MagicMock() mock_input_stream_new_method.return_value = self.mock_input_stream self.record = audio_record.AudioRecord(_CHANNELS, _SAMPLING_RATE, _BUFFER_SIZE) @@ -72,13 +71,13 @@ # Assert read data of a single chunk. recorded_audio_data = self.record.read(chunk_size) - testing.assert_almost_equal(recorded_audio_data, input_data[-1]) + self.assertAllClose(recorded_audio_data, input_data[-1]) # Assert read all data in buffer. recorded_audio_data = self.record.read(chunk_size * 2) print(input_data[-2].shape) expected_data = np.concatenate(input_data[-2:]) - testing.assert_almost_equal(recorded_audio_data, expected_data) + self.assertAllClose(recorded_audio_data, expected_data) def test_read_fails_with_invalid_sample_size(self): callback_fn = self.init_args["callback"] @@ -93,4 +92,4 @@ if __name__ == "__main__": - unittest.main() + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/tensor_audio_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/tensor_audio_test.py index a7459618c..94668c7 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/tensor_audio_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/audio/core/tensor_audio_test.py
@@ -14,7 +14,7 @@ """Tests for tensor_audio.""" from absl.testing import parameterized import numpy as np -from numpy import testing +import tensorflow as tf import unittest from tensorflow_lite_support.python.task.audio.core import audio_record @@ -30,7 +30,7 @@ _BUFFER_SIZE = 15600 -class TensorAudioTest(parameterized.TestCase, unittest.TestCase): +class TensorAudioTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): super().setUp() @@ -69,7 +69,7 @@ self.assertEqual(audio_format.sample_rate, _SAMPLE_RATE) self.assertEqual(self.test_tensor_audio.buffer_size, _BUFFER_SIZE) self.assertIsInstance(audio_buffer, np.ndarray) - testing.assert_almost_equal(audio_buffer, array) + self.assertAllClose(audio_buffer, array) def test_load_from_array_succeeds_with_larger_input_size_and_default_params( self): @@ -84,7 +84,7 @@ self.assertEqual(audio_format.sample_rate, _SAMPLE_RATE) self.assertEqual(self.test_tensor_audio.buffer_size, _BUFFER_SIZE) self.assertIsInstance(audio_buffer, np.ndarray) - testing.assert_almost_equal(audio_buffer, array[_BUFFER_SIZE:]) + self.assertAllClose(audio_buffer, array[_BUFFER_SIZE:]) @parameterized.parameters((0, 15600), (7800, 15600)) def test_load_from_array_succeeds_with_larger_input_size_and_params_specified( @@ -100,7 +100,7 @@ self.assertEqual(audio_format.sample_rate, _SAMPLE_RATE) self.assertEqual(self.test_tensor_audio.buffer_size, _BUFFER_SIZE) self.assertIsInstance(audio_buffer, np.ndarray) - testing.assert_almost_equal(audio_buffer, array[offset:offset + size]) + self.assertAllClose(audio_buffer, array[offset:offset + size]) def test_load_from_array_succeeds_with_smaller_input_size_and_default_params( self): @@ -116,7 +116,7 @@ self.assertEqual(audio_format.sample_rate, _SAMPLE_RATE) self.assertEqual(self.test_tensor_audio.buffer_size, _BUFFER_SIZE) self.assertIsInstance(audio_buffer, np.ndarray) - testing.assert_almost_equal(audio_buffer[-input_length:], array) + self.assertAllClose(audio_buffer[-input_length:], array) 
@parameterized.parameters((0, 4000), (3900, 100)) def test_load_from_array_succeeds_with_smaller_input_size_and_params_specified( @@ -132,8 +132,7 @@ self.assertEqual(audio_format.sample_rate, _SAMPLE_RATE) self.assertEqual(self.test_tensor_audio.buffer_size, _BUFFER_SIZE) self.assertIsInstance(audio_buffer, np.ndarray) - testing.assert_almost_equal(audio_buffer[-size:], - array[offset:offset + size]) + self.assertAllClose(audio_buffer[-size:], array[offset:offset + size]) @parameterized.parameters((7800, 15600), (0, 20000)) def test_load_from_array_fails_with_invalid_offset_size(self, offset, size): @@ -177,7 +176,7 @@ self.test_tensor_audio.load_from_audio_record(record) # Assert read all data in the float buffer. - testing.assert_almost_equal(self.test_tensor_audio.buffer, expected_data) + self.assertAllClose(self.test_tensor_audio.buffer, expected_data) @_mock.patch("sounddevice.InputStream", return_value=_mock.MagicMock()) def test_load_from_audio_record_fails_with_invalid_buffer_size(self, _): @@ -213,4 +212,4 @@ if __name__ == "__main__": - unittest.main() + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/text/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/text/BUILD index 38708c5..e71b9e8 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/text/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/text/BUILD
@@ -11,12 +11,35 @@ data = [ "//tensorflow_lite_support/cc/test/testdata/task/text:mobilebert_embedding_with_metadata", "//tensorflow_lite_support/cc/test/testdata/task/text:regex_embedding_with_metadata", + "//tensorflow_lite_support/cc/test/testdata/task/text:universal_sentence_encoder_qa", + ], + deps = [ + # build rule placeholder: numpy dep, + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", + "//tensorflow_lite_support/python/task/processor/proto:embedding_options_pb2", + "//tensorflow_lite_support/python/task/text:text_embedder", + "//tensorflow_lite_support/python/test:test_util", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "text_searcher_test", + srcs = ["text_searcher_test.py"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:mobilebert_embedding_with_metadata", + "//tensorflow_lite_support/cc/test/testdata/task/text:regex_embedding_with_metadata", + "//tensorflow_lite_support/cc/test/testdata/task/text:test_indices", + "//tensorflow_lite_support/cc/test/testdata/task/text:test_searchers", + "//tensorflow_lite_support/cc/test/testdata/task/text:universal_sentence_encoder_qa", ], deps = [ # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", "//tensorflow_lite_support/python/task/processor/proto:embedding_options_pb2", - "//tensorflow_lite_support/python/task/text:text_embedder", + "//tensorflow_lite_support/python/task/processor/proto:search_options_py_pb2", + "//tensorflow_lite_support/python/task/text:text_searcher", "//tensorflow_lite_support/python/test:test_util", "@absl_py//absl/testing:parameterized", ],
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/text/text_embedder_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/text/text_embedder_test.py index 5393cb19..044b89ce 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/text/text_embedder_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/text/text_embedder_test.py
@@ -16,8 +16,9 @@ import enum from absl.testing import parameterized - +import numpy as np import tensorflow as tf + from tensorflow_lite_support.python.task.core.proto import base_options_pb2 from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2 from tensorflow_lite_support.python.task.text import text_embedder @@ -29,6 +30,7 @@ _REGEX_MODEL = "regex_one_embedding_with_metadata.tflite" _BERT_MODEL = "mobilebert_embedding_with_metadata.tflite" +_USE_MODEL = "universal_sentence_encoder_qa_with_metadata.tflite" class ModelFileType(enum.Enum): @@ -69,13 +71,18 @@ self.assertIsInstance(embedder, _TextEmbedder) @parameterized.parameters( - (_REGEX_MODEL, False, False, ModelFileType.FILE_NAME, 16, 0.999937), - (_REGEX_MODEL, True, True, ModelFileType.FILE_NAME, 16, 0.999878), - (_BERT_MODEL, False, False, ModelFileType.FILE_CONTENT, 512, 0.969514), - (_BERT_MODEL, True, True, ModelFileType.FILE_CONTENT, 512, 0.966984), + (_REGEX_MODEL, False, False, ModelFileType.FILE_NAME, 16, 0.999937, + 0.03093561), + (_REGEX_MODEL, True, True, ModelFileType.FILE_NAME, 16, 0.999878, 70), + (_BERT_MODEL, False, False, ModelFileType.FILE_CONTENT, 512, 0.969514, + 19.901617), + (_BERT_MODEL, True, True, ModelFileType.FILE_CONTENT, 512, 0.966984, 7), + (_USE_MODEL, False, False, ModelFileType.FILE_NAME, 100, 0.851961, + 1.4229515), + (_USE_MODEL, True, True, ModelFileType.FILE_CONTENT, 100, 0.852664, 16), ) def test_embed(self, model_name, l2_normalize, quantize, model_file_type, - embedding_length, expected_similarity): + embedding_length, expected_similarity, expected_first_value): # Create embedder. model_path = test_util.get_test_data_path(model_name) if model_file_type is ModelFileType.FILE_NAME: @@ -99,16 +106,28 @@ result1 = embedder.embed("what a great and fantastic trip") # Check embedding sizes. 
- def _check_embedding_size(result): - self.assertLen(result.embeddings, 1) - feature_vector = result.embeddings[0].feature_vector - if quantize: - self.assertLen(feature_vector.value_string, embedding_length) - else: - self.assertLen(feature_vector.value_float, embedding_length) + self.assertLen(result0.embeddings, 1) + result0_feature_vector = result0.embeddings[0].feature_vector + self.assertLen(result1.embeddings, 1) + result1_feature_vector = result1.embeddings[0].feature_vector - _check_embedding_size(result0) - _check_embedding_size(result1) + self.assertLen(result0_feature_vector.value, embedding_length) + self.assertLen(result1_feature_vector.value, embedding_length) + + if quantize: + self.assertEqual(result0_feature_vector.value.dtype, np.uint8) + else: + self.assertEqual(result1_feature_vector.value.dtype, float) + + # Check embedding value. + self.assertAlmostEqual( + result0_feature_vector.value[0], expected_first_value, places=3) + + # Checks cosine similarity. + similarity = embedder.cosine_similarity( + result0.embeddings[0].feature_vector, + result1.embeddings[0].feature_vector) + self.assertAlmostEqual(similarity, expected_similarity, places=4) def test_get_embedding_dimension(self): options = _TextEmbedderOptions(_BaseOptions(file_name=self.model_path))
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/text/text_searcher_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/text/text_searcher_test.py new file mode 100644 index 0000000..0547d4e --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/text/text_searcher_test.py
@@ -0,0 +1,366 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for text_searcher.""" + +import enum + +from absl.testing import parameterized + +import tensorflow as tf +from tensorflow_lite_support.python.task.core.proto import base_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import search_options_pb2 +from tensorflow_lite_support.python.task.text import text_searcher +from tensorflow_lite_support.python.test import test_util + +_BaseOptions = base_options_pb2.BaseOptions +_EmbeddingOptions = embedding_options_pb2.EmbeddingOptions +_SearchOptions = search_options_pb2.SearchOptions +_TextSearcher = text_searcher.TextSearcher +_TextSearcherOptions = text_searcher.TextSearcherOptions + +_REGEX_EMBEDDER_MODEL = 'regex_one_embedding_with_metadata.tflite' +_REGEX_SEARCHER_MODEL = 'regex_searcher.tflite' +_REGEX_INDEX = 'regex_index.ldb' +_EXPECTED_REGEX_SEARCH_RESULT = """ +nearest_neighbors { + metadata: "The weather was excellent." + distance: 0.0 +} +nearest_neighbors { + metadata: "The sun was shining on that day." + distance: 5.7e-5 +} +nearest_neighbors { + metadata: "The cat is chasing after the mouse." + distance: 8.9e-5 +} +nearest_neighbors { + metadata: "It was a sunny day." + distance: 0.000113 +} +nearest_neighbors { + metadata: "He was very happy with his newly bought car." 
+ distance: 0.000119 +} +""" +_EXPECTED_REGEX_DEFAULT_OPTIONS_SEARCH_RESULT = """ +nearest_neighbors { + metadata: "The weather was excellent." + distance: 0.889665 +} +nearest_neighbors { + metadata: "The sun was shining on that day." + distance: 0.889668 +} +nearest_neighbors { + metadata: "The cat is chasing after the mouse." + distance: 0.88967 +} +nearest_neighbors { + metadata: "It was a sunny day." + distance: 0.889671 +} +nearest_neighbors { + metadata: "He was very happy with his newly bought car." + distance: 0.889672 +} +""" + +_BERT_EMBEDDER_MODEL = 'mobilebert_embedding_with_metadata.tflite' +_BERT_SEARCHER_MODEL = 'mobilebert_searcher.tflite' +_BERT_INDEX = 'mobilebert_index.ldb' +_EXPECTED_BERT_SEARCH_RESULT = """ +nearest_neighbors { + metadata: "The weather was excellent." + distance: 0.0 +} +nearest_neighbors { + metadata: "It was a sunny day." + distance: 0.115369 +} +nearest_neighbors { + metadata: "The sun was shining on that day." + distance: 0.230017 +} +nearest_neighbors { + metadata: "He was very happy with his newly bought car." + distance: 0.324563 +} +nearest_neighbors { + metadata: "The cat is chasing after the mouse." + distance: 0.966928 +} +""" + +_USE_EMBEDDER_MODEL = 'universal_sentence_encoder_qa_with_metadata.tflite' +_USE_SEARCHER_MODEL = 'universal_sentence_encoder_searcher.tflite' +_USE_INDEX = 'universal_sentence_encoder_index.ldb' +_EXPECTED_USE_SEARCH_RESULT = """ +nearest_neighbors { + metadata: "The weather was excellent." + distance: 0.0 +} +nearest_neighbors { + metadata: "It was a sunny day." + distance: 0.146359 +} +nearest_neighbors { + metadata: "The sun was shining on that day." + distance: 0.152225 +} +nearest_neighbors { + metadata: "The cat is chasing after the mouse." + distance: 0.359965 +} +nearest_neighbors { + metadata: "He was very happy with his newly bought car." 
+ distance: 0.366927 +} +""" + +_MAX_RESULTS = 2 + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class IndexFileType(enum.Enum): + NONE = 1 + FILE_CONTENT = 2 + FILE_NAME = 3 + + +class TextSearcherTest(parameterized.TestCase, tf.test.TestCase): + + def setUp(self): + super().setUp() + self.embedder_model_path = test_util.get_test_data_path( + _REGEX_EMBEDDER_MODEL) + self.searcher_model_path = test_util.get_test_data_path( + _REGEX_SEARCHER_MODEL) + self.index_path = test_util.get_test_data_path(_REGEX_INDEX) + + def test_create_from_file_succeeds_with_valid_embedder_and_index_paths(self): + # Creates with default option and valid model and index files successfully. + searcher = _TextSearcher.create_from_file(self.embedder_model_path, + self.index_path) + self.assertIsInstance(searcher, _TextSearcher) + + def test_create_from_file_succeeds_with_valid_searcher_path(self): + # Creates with default option and valid model and index files successfully. + searcher = _TextSearcher.create_from_file(self.searcher_model_path) + self.assertIsInstance(searcher, _TextSearcher) + + def test_create_from_options_succeeds_with_valid_embedder_and_index_paths( + self): + options = _TextSearcherOptions( + base_options=_BaseOptions(file_name=self.embedder_model_path), + search_options=_SearchOptions(index_file_name=self.index_path)) + searcher = _TextSearcher.create_from_options(options) + self.assertIsInstance(searcher, _TextSearcher) + + def test_create_from_options_succeeds_with_valid_searcher_path(self): + options = _TextSearcherOptions( + base_options=_BaseOptions(file_name=self.searcher_model_path), + search_options=_SearchOptions()) + searcher = _TextSearcher.create_from_options(options) + self.assertIsInstance(searcher, _TextSearcher) + + def test_create_from_options_succeeds_with_valid_embedder_content(self): + # Creates with options containing model content successfully. 
+ with open(self.embedder_model_path, 'rb') as f: + options = _TextSearcherOptions( + base_options=_BaseOptions(file_content=f.read()), + search_options=_SearchOptions(index_file_name=self.index_path)) + searcher = _TextSearcher.create_from_options(options) + self.assertIsInstance(searcher, _TextSearcher) + + def test_create_from_options_succeeds_with_valid_searcher_content(self): + # Creates with options containing model content successfully. + with open(self.searcher_model_path, 'rb') as f: + options = _TextSearcherOptions( + base_options=_BaseOptions(file_content=f.read()), + search_options=_SearchOptions(index_file_name=self.index_path)) + searcher = _TextSearcher.create_from_options(options) + self.assertIsInstance(searcher, _TextSearcher) + + def test_create_from_options_succeeds_with_valid_index_content(self): + # Creates with options containing index content successfully. + with open(self.index_path, 'rb') as f: + options = _TextSearcherOptions( + base_options=_BaseOptions(file_name=self.embedder_model_path), + search_options=_SearchOptions(index_file_content=f.read())) + searcher = _TextSearcher.create_from_options(options) + self.assertIsInstance(searcher, _TextSearcher) + + def test_create_from_options_fails_with_invalid_index_path(self): + # Invalid index path. + with self.assertRaisesRegex( + ValueError, + r'Unable to find index file: SearchOptions.index_file is not set and ' + r'no AssociatedFile with type SCANN_INDEX_FILE could be found in the ' + r'output tensor metadata.'): + options = _TextSearcherOptions( + base_options=_BaseOptions(file_name=self.embedder_model_path)) + _TextSearcher.create_from_options(options) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. 
+ with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name' or 'file_descriptor_meta'."): + options = _TextSearcherOptions( + base_options=_BaseOptions(file_name=''), + search_options=_SearchOptions(index_file_name=self.index_path)) + _TextSearcher.create_from_options(options) + + def test_create_from_options_fails_with_invalid_quantization(self): + # Invalid quantization option. + with self.assertRaisesRegex( + ValueError, + r'Setting EmbeddingOptions.quantize = true is not allowed in ' + r'searchers.'): + options = _TextSearcherOptions( + base_options=_BaseOptions(file_name=self.embedder_model_path), + embedding_options=_EmbeddingOptions(quantize=True), + search_options=_SearchOptions(index_file_name=self.index_path)) + _TextSearcher.create_from_options(options) + + def test_create_from_options_fails_with_invalid_max_results(self): + # Invalid max results option. + with self.assertRaisesRegex( + ValueError, r'SearchOptions.max_results must be > 0, found -1.'): + options = _TextSearcherOptions( + base_options=_BaseOptions(file_name=self.embedder_model_path), + search_options=_SearchOptions( + index_file_name=self.index_path, max_results=-1)) + _TextSearcher.create_from_options(options) + + def test_search_with_default_options(self): + # Create searcher. + searcher = _TextSearcher.create_from_file(self.embedder_model_path, + self.index_path) + + # Perform text search. 
+ text_search_result = searcher.search('The weather was excellent.') + + self.assertProtoEquals(_EXPECTED_REGEX_DEFAULT_OPTIONS_SEARCH_RESULT, + text_search_result) + + @parameterized.parameters( + (_REGEX_EMBEDDER_MODEL, _REGEX_INDEX, ModelFileType.FILE_NAME, + IndexFileType.FILE_NAME, _EXPECTED_REGEX_SEARCH_RESULT), + (_REGEX_EMBEDDER_MODEL, _REGEX_INDEX, ModelFileType.FILE_CONTENT, + IndexFileType.FILE_NAME, _EXPECTED_REGEX_SEARCH_RESULT), + (_REGEX_EMBEDDER_MODEL, _REGEX_INDEX, ModelFileType.FILE_NAME, + IndexFileType.FILE_CONTENT, _EXPECTED_REGEX_SEARCH_RESULT), + (_REGEX_EMBEDDER_MODEL, _REGEX_INDEX, ModelFileType.FILE_CONTENT, + IndexFileType.FILE_CONTENT, _EXPECTED_REGEX_SEARCH_RESULT), + (_REGEX_SEARCHER_MODEL, None, ModelFileType.FILE_NAME, IndexFileType.NONE, + _EXPECTED_REGEX_SEARCH_RESULT), + (_REGEX_SEARCHER_MODEL, None, ModelFileType.FILE_CONTENT, + IndexFileType.NONE, _EXPECTED_REGEX_SEARCH_RESULT), + (_BERT_EMBEDDER_MODEL, _BERT_INDEX, ModelFileType.FILE_NAME, + IndexFileType.FILE_NAME, _EXPECTED_BERT_SEARCH_RESULT), + (_BERT_EMBEDDER_MODEL, _BERT_INDEX, ModelFileType.FILE_CONTENT, + IndexFileType.FILE_NAME, _EXPECTED_BERT_SEARCH_RESULT), + (_BERT_EMBEDDER_MODEL, _BERT_INDEX, ModelFileType.FILE_NAME, + IndexFileType.FILE_CONTENT, _EXPECTED_BERT_SEARCH_RESULT), + (_BERT_EMBEDDER_MODEL, _BERT_INDEX, ModelFileType.FILE_CONTENT, + IndexFileType.FILE_CONTENT, _EXPECTED_BERT_SEARCH_RESULT), + (_BERT_SEARCHER_MODEL, None, ModelFileType.FILE_NAME, IndexFileType.NONE, + _EXPECTED_BERT_SEARCH_RESULT), + (_BERT_SEARCHER_MODEL, None, ModelFileType.FILE_CONTENT, + IndexFileType.NONE, _EXPECTED_BERT_SEARCH_RESULT), + (_USE_EMBEDDER_MODEL, _USE_INDEX, ModelFileType.FILE_NAME, + IndexFileType.FILE_NAME, _EXPECTED_USE_SEARCH_RESULT), + (_USE_EMBEDDER_MODEL, _USE_INDEX, ModelFileType.FILE_CONTENT, + IndexFileType.FILE_NAME, _EXPECTED_USE_SEARCH_RESULT), + (_USE_EMBEDDER_MODEL, _USE_INDEX, ModelFileType.FILE_NAME, + IndexFileType.FILE_CONTENT, 
_EXPECTED_USE_SEARCH_RESULT), + (_USE_EMBEDDER_MODEL, _USE_INDEX, ModelFileType.FILE_CONTENT, + IndexFileType.FILE_CONTENT, _EXPECTED_USE_SEARCH_RESULT), + (_USE_SEARCHER_MODEL, None, ModelFileType.FILE_NAME, IndexFileType.NONE, + _EXPECTED_USE_SEARCH_RESULT), + (_USE_SEARCHER_MODEL, None, ModelFileType.FILE_CONTENT, + IndexFileType.NONE, _EXPECTED_USE_SEARCH_RESULT), + ) + def test_search(self, model_name, index_name, model_file_type, + index_file_type, expected_result_text_proto): + # Create BaseOptions. + model_path = test_util.get_test_data_path(model_name) + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(file_name=model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(file_content=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + # Create SearchOptions. + if index_file_type is IndexFileType.NONE: + search_options = _SearchOptions() + else: + index_path = test_util.get_test_data_path(index_name) + if index_file_type is IndexFileType.FILE_NAME: + search_options = _SearchOptions(index_file_name=index_path) + elif index_file_type is IndexFileType.FILE_CONTENT: + with open(index_path, 'rb') as f: + index_content = f.read() + search_options = _SearchOptions(index_file_content=index_content) + else: + # Should never happen + raise ValueError('index_file_type is invalid.') + + # Create searcher. + options = _TextSearcherOptions( + base_options, _EmbeddingOptions(l2_normalize=True, quantize=False), + search_options) + searcher = _TextSearcher.create_from_options(options) + + # Perform text search. + text_search_result = searcher.search('The weather was excellent.') + + # Comparing results. + self.assertProtoEquals(expected_result_text_proto, text_search_result) + + # Get user info and compare values. 
+ self.assertEqual(searcher.get_user_info(), 'userinfo') + + def test_max_results_option(self): + # Create searcher. + base_options = _BaseOptions(file_name=self.embedder_model_path) + search_options = _SearchOptions( + index_file_name=self.index_path, max_results=_MAX_RESULTS) + options = _TextSearcherOptions(base_options, + _EmbeddingOptions(l2_normalize=True), + search_options) + searcher = _TextSearcher.create_from_options(options) + + # Perform text search. + text_search_result = searcher.search('The weather was excellent.') + nearest_neighbors = text_search_result.nearest_neighbors + + self.assertLessEqual( + len(nearest_neighbors), _MAX_RESULTS, 'Too many results returned.') + + +if __name__ == '__main__': + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/BUILD index 7714262..88c6d32 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/BUILD
@@ -13,6 +13,7 @@ "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", ], deps = [ + # build rule placeholder: numpy dep, # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", "//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2", @@ -33,17 +34,15 @@ "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", ], deps = [ + # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", "//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2", - "//tensorflow_lite_support/python/task/processor/proto:class_pb2", "//tensorflow_lite_support/python/task/processor/proto:classification_options_pb2", "//tensorflow_lite_support/python/task/processor/proto:classifications_pb2", "//tensorflow_lite_support/python/task/vision:image_classifier", "//tensorflow_lite_support/python/task/vision/core:tensor_image", - "//tensorflow_lite_support/python/test:base_test", "//tensorflow_lite_support/python/test:test_util", "@absl_py//absl/testing:parameterized", - "@com_google_protobuf//:protobuf_python", ], ) @@ -55,12 +54,34 @@ "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", ], deps = [ + # build rule placeholder: numpy dep, # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", "//tensorflow_lite_support/python/task/processor/proto:segmentation_options_pb2", + "//tensorflow_lite_support/python/task/processor/proto:segmentations_pb2", "//tensorflow_lite_support/python/task/vision:image_segmenter", "//tensorflow_lite_support/python/task/vision/core:tensor_image", - "//tensorflow_lite_support/python/test:base_test", + "//tensorflow_lite_support/python/test:test_util", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "image_searcher_test", + srcs = ["image_searcher_test.py"], + data = [ + 
"//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_indices", + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", + ], + deps = [ + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", + "//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2", + "//tensorflow_lite_support/python/task/processor/proto:embedding_options_pb2", + "//tensorflow_lite_support/python/task/processor/proto:search_options_py_pb2", + "//tensorflow_lite_support/python/task/vision:image_searcher", + "//tensorflow_lite_support/python/task/vision/core:tensor_image", "//tensorflow_lite_support/python/test:test_util", "@absl_py//absl/testing:parameterized", ], @@ -74,16 +95,13 @@ "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models", ], deps = [ + # build rule placeholder: tensorflow dep, "//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2", - "//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2", - "//tensorflow_lite_support/python/task/processor/proto:class_pb2", "//tensorflow_lite_support/python/task/processor/proto:detection_options_pb2", "//tensorflow_lite_support/python/task/processor/proto:detections_pb2", "//tensorflow_lite_support/python/task/vision:object_detector", "//tensorflow_lite_support/python/task/vision/core:tensor_image", - "//tensorflow_lite_support/python/test:base_test", "//tensorflow_lite_support/python/test:test_util", "@absl_py//absl/testing:parameterized", - "@com_google_protobuf//:protobuf_python", ], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/core/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/core/BUILD new file mode 100644 index 0000000..6287713f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/core/BUILD
@@ -0,0 +1,23 @@ +# Placeholder for internal Python strict test compatibility macro. + +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +py_test( + name = "tensor_image_test", + srcs = ["tensor_image_test.py"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images", + ], + deps = [ + # build rule placeholder: numpy dep, + # build rule placeholder: tensorflow dep, + "//tensorflow_lite_support/python/task/vision/core:color_space_type", + "//tensorflow_lite_support/python/task/vision/core:tensor_image", + "//tensorflow_lite_support/python/task/vision/core/pybinds:image_utils", + "//tensorflow_lite_support/python/test:test_util", + "@absl_py//absl/testing:parameterized", + ], +)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/tensor_image_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/core/tensor_image_test.py similarity index 97% rename from third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/tensor_image_test.py rename to third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/core/tensor_image_test.py index d6d4761..0a74c273 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/tensor_image_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/core/tensor_image_test.py
@@ -51,7 +51,7 @@ self.assertEqual(image.width, width) self.assertEqual(image.color_space_type, color_type) self.assertIsInstance(image.buffer, np.ndarray) - self.assertAllEqual(image.buffer, array) + self.assertAllClose(image.buffer, array) if __name__ == '__main__':
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py index dff6110..654c02e 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_classifier_test.py
@@ -14,21 +14,15 @@ """Tests for image_classifier.""" import enum -import json from absl.testing import parameterized -from google.protobuf import json_format -# TODO(b/220067158): Change to import tensorflow and leverage tf.test once -# fixed the dependency issue. -import unittest +import tensorflow as tf + from tensorflow_lite_support.python.task.core.proto import base_options_pb2 from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2 -from tensorflow_lite_support.python.task.processor.proto import class_pb2 from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2 -from tensorflow_lite_support.python.task.processor.proto import classifications_pb2 from tensorflow_lite_support.python.task.vision import image_classifier from tensorflow_lite_support.python.task.vision.core import tensor_image -from tensorflow_lite_support.python.test import base_test from tensorflow_lite_support.python.test import test_util _BaseOptions = base_options_pb2.BaseOptions @@ -41,7 +35,6 @@ _DENY_LIST = ['cheeseburger'] _SCORE_THRESHOLD = 0.5 _MAX_RESULTS = 3 -_ACCEPTABLE_ERROR_RANGE = 0.000001 def _create_classifier_from_options(base_options, **classification_options): @@ -53,23 +46,12 @@ return classifier -def _build_test_data(expected_categories): - classifications = classifications_pb2.Classifications(head_index=0) - classifications.classes.extend( - [class_pb2.Category(**args) for args in expected_categories]) - expected_result = classifications_pb2.ClassificationResult() - expected_result.classifications.append(classifications) - expected_result_dict = json.loads(json_format.MessageToJson(expected_result)) - - return expected_result_dict - - class ModelFileType(enum.Enum): FILE_CONTENT = 1 FILE_NAME = 2 -class ImageClassifierTest(parameterized.TestCase, base_test.BaseTestCase): +class ImageClassifierTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): super().setUp() @@ -106,33 +88,55 @@ classifier = 
_ImageClassifier.create_from_options(options) self.assertIsInstance(classifier, _ImageClassifier) - @parameterized.parameters((ModelFileType.FILE_NAME, 3, [{ - 'index': 934, - 'score': 0.7399742007255554, - 'class_name': 'cheeseburger' - }, { - 'index': 925, - 'score': 0.026928534731268883, - 'class_name': 'guacamole' - }, { - 'index': 932, - 'score': 0.025737214833498, - 'class_name': 'bagel' - }]), (ModelFileType.FILE_CONTENT, 3, [{ - 'index': 934, - 'score': 0.7399742007255554, - 'class_name': 'cheeseburger' - }, { - 'index': 925, - 'score': 0.026928534731268883, - 'class_name': 'guacamole' - }, { - 'index': 932, - 'score': 0.025737214833498, - 'class_name': 'bagel' - }])) + @parameterized.parameters((ModelFileType.FILE_NAME, 3, """ + classifications { + classes { + index: 934 + score: 0.739974 + display_name: "" + class_name: "cheeseburger" + } + classes { + index: 925 + score: 0.026929 + display_name: "" + class_name: "guacamole" + } + classes { + index: 932 + score: 0.025737 + display_name: "" + class_name: "bagel" + } + head_index: 0 + head_name: "" + } + """), (ModelFileType.FILE_CONTENT, 3, """ + classifications { + classes { + index: 934 + score: 0.739974 + display_name: "" + class_name: "cheeseburger" + } + classes { + index: 925 + score: 0.026929 + display_name: "" + class_name: "guacamole" + } + classes { + index: 932 + score: 0.025737 + display_name: "" + class_name: "bagel" + } + head_index: 0 + head_name: "" + } + """)) def test_classify_model(self, model_file_type, max_results, - expected_categories): + expected_result_text_proto): # Creates classifier. if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(file_name=self.model_path) @@ -152,14 +156,9 @@ # Classifies the input. image_result = classifier.classify(image, bounding_box=None) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - # Builds test data. 
- expected_result_dict = _build_test_data(expected_categories) # Comparing results (classification w/o bounding box). - self.assertDeepAlmostEqual( - image_result_dict, expected_result_dict, delta=_ACCEPTABLE_ERROR_RANGE) + self.assertProtoEquals(expected_result_text_proto, image_result.to_pb2()) def test_classify_model_with_bounding_box(self): # Creates classifier. @@ -176,29 +175,35 @@ # Classifies the input. image_result = classifier.classify(image, bounding_box) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) # Expected results. - expected_categories = [{ - 'index': 934, - 'score': 0.8815076351165771, - 'class_name': 'cheeseburger' - }, { - 'index': 925, - 'score': 0.019456762820482254, - 'class_name': 'guacamole' - }, { - 'index': 932, - 'score': 0.012489477172493935, - 'class_name': 'bagel' - }] - - # Builds test data. - expected_result_dict = _build_test_data(expected_categories) + expected_result_text_proto = """ + classifications { + classes { + index: 934 + score: 0.881507 + display_name: "" + class_name: "cheeseburger" + } + classes { + index: 925 + score: 0.019457 + display_name: "" + class_name: "guacamole" + } + classes { + index: 932 + score: 0.012489 + display_name: "" + class_name: "bagel" + } + head_index: 0 + head_name: "" + } + """ # Comparing results (classification w/ bounding box). - self.assertDeepAlmostEqual( - image_result_dict, expected_result_dict, delta=_ACCEPTABLE_ERROR_RANGE) + self.assertProtoEquals(expected_result_text_proto, image_result.to_pb2()) def test_max_results_option(self): # Creates classifier. @@ -212,9 +217,7 @@ # Classifies the input. 
image_result = classifier.classify(image, bounding_box=None) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - categories = image_result_dict['classifications'][0]['classes'] + categories = image_result.classifications[0].categories self.assertLessEqual( len(categories), _MAX_RESULTS, 'Too many results returned.') @@ -231,59 +234,50 @@ # Classifies the input. image_result = classifier.classify(image, bounding_box=None) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - categories = image_result_dict['classifications'][0]['classes'] + categories = image_result.classifications[0].categories for category in categories: - score = category['score'] self.assertGreaterEqual( - score, _SCORE_THRESHOLD, - 'Classification with score lower than threshold found. {0}'.format( - category)) + category.score, _SCORE_THRESHOLD, + f'Classification with score lower than threshold found. {category}') def test_allowlist_option(self): # Creates classifier. base_options = _BaseOptions(file_name=self.model_path) classifier = _create_classifier_from_options( - base_options, class_name_allowlist=_ALLOW_LIST) + base_options, category_name_allowlist=_ALLOW_LIST) # Loads image. image = tensor_image.TensorImage.create_from_file(self.test_image_path) # Classifies the input. image_result = classifier.classify(image, bounding_box=None) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - categories = image_result_dict['classifications'][0]['classes'] + categories = image_result.classifications[0].categories for category in categories: - label = category['className'] - self.assertIn( - label, _ALLOW_LIST, - 'Label "{0}" found but not in label allow list'.format(label)) + label = category.category_name + self.assertIn(label, _ALLOW_LIST, + f'Label {label} found but not in label allow list') def test_denylist_option(self): # Creates classifier. 
base_options = _BaseOptions(file_name=self.model_path) classifier = _create_classifier_from_options( - base_options, score_threshold=0.01, class_name_denylist=_DENY_LIST) + base_options, score_threshold=0.01, category_name_denylist=_DENY_LIST) # Loads image image = tensor_image.TensorImage.create_from_file(self.test_image_path) # Classifies the input. image_result = classifier.classify(image, bounding_box=None) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - categories = image_result_dict['classifications'][0]['classes'] + categories = image_result.classifications[0].categories for category in categories: - label = category['className'] + label = category.category_name self.assertNotIn(label, _DENY_LIST, - 'Label "{0}" found but in deny list.'.format(label)) + f'Label {label} found but in deny list.') def test_combined_allowlist_and_denylist(self): # Fails with combined allowlist and denylist @@ -293,7 +287,7 @@ r'exclusive options.'): base_options = _BaseOptions(file_name=self.model_path) classification_options = classification_options_pb2.ClassificationOptions( - class_name_allowlist=['foo'], class_name_denylist=['bar']) + category_name_allowlist=['foo'], category_name_denylist=['bar']) options = _ImageClassifierOptions( base_options=base_options, classification_options=classification_options) @@ -301,4 +295,4 @@ if __name__ == '__main__': - unittest.main() + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_embedder_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_embedder_test.py index d3514ee..5336f28 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_embedder_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_embedder_test.py
@@ -16,8 +16,9 @@ import enum from absl.testing import parameterized - +import numpy as np import tensorflow as tf + from tensorflow_lite_support.python.task.core.proto import base_options_pb2 from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2 from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2 @@ -26,6 +27,7 @@ from tensorflow_lite_support.python.task.vision.core import tensor_image from tensorflow_lite_support.python.test import test_util + _BaseOptions = base_options_pb2.BaseOptions _ImageEmbedder = image_embedder.ImageEmbedder _ImageEmbedderOptions = image_embedder.ImageEmbedderOptions @@ -74,13 +76,13 @@ self.assertIsInstance(embedder, _ImageEmbedder) @parameterized.parameters( - (False, False, False, ModelFileType.FILE_NAME, 0.932738), - (True, False, False, ModelFileType.FILE_NAME, 0.932738), - (True, True, False, ModelFileType.FILE_CONTENT, 0.929717), - (False, False, True, ModelFileType.FILE_CONTENT, 0.999914), + (False, False, False, ModelFileType.FILE_NAME, 0.932738, -0.20580328), + (True, False, False, ModelFileType.FILE_NAME, 0.932738, -0.0135661615), + (True, True, False, ModelFileType.FILE_CONTENT, 0.929717, 254), + (False, False, True, ModelFileType.FILE_CONTENT, 0.999914, -0.16619979), ) def test_embed(self, l2_normalize, quantize, with_bounding_box, - model_file_type, expected_similarity): + model_file_type, expected_similarity, expected_first_value): # Creates embedder. 
if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(file_name=self.model_path) @@ -119,12 +121,17 @@ image_feature_vector = image_result.embeddings[0].feature_vector self.assertLen(crop_result.embeddings, 1) crop_feature_vector = crop_result.embeddings[0].feature_vector + + self.assertLen(image_feature_vector.value, 1024) + self.assertLen(crop_feature_vector.value, 1024) + if quantize: - self.assertLen(image_feature_vector.value_string, 1024) - self.assertLen(crop_feature_vector.value_string, 1024) + self.assertEqual(image_feature_vector.value.dtype, np.uint8) else: - self.assertLen(image_feature_vector.value_float, 1024) - self.assertLen(crop_feature_vector.value_float, 1024) + self.assertEqual(image_feature_vector.value.dtype, float) + + # Check embedding value. + self.assertAlmostEqual(image_feature_vector.value[0], expected_first_value) # Checks cosine similarity. similarity = embedder.cosine_similarity(image_feature_vector, @@ -137,16 +144,15 @@ embedder = _ImageEmbedder.create_from_options(options) # Builds test data. 
- embedding = embedding_pb2.Embedding(output_index=0) - embedding.feature_vector.value_float.append(1.0) - embedding.feature_vector.value_float.append(0.0) - embedding_result = embedding_pb2.EmbeddingResult() - embedding_result.embeddings.append(embedding) + feature_vector = embedding_pb2.FeatureVector(value=np.array([1.0, 0.0])) + embedding = embedding_pb2.Embedding( + output_index=0, feature_vector=feature_vector) + embedding_result = embedding_pb2.EmbeddingResult(embeddings=[embedding]) result0 = embedder.get_embedding_by_index(embedding_result, 0) self.assertEqual(result0.output_index, 0) - self.assertEqual(result0.feature_vector.value_float[0], 1.0) - self.assertEqual(result0.feature_vector.value_float[1], 0.0) + self.assertEqual(result0.feature_vector.value[0], 1.0) + self.assertEqual(result0.feature_vector.value[1], 0.0) with self.assertRaisesRegex(ValueError, r"Output index is out of bound\."): embedder.get_embedding_by_index(embedding_result, 1)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_searcher_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_searcher_test.py new file mode 100644 index 0000000..3ed13a49 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_searcher_test.py
@@ -0,0 +1,282 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for image_searcher.""" + +import enum + +from absl.testing import parameterized + +import tensorflow as tf +from tensorflow_lite_support.python.task.core.proto import base_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2 +from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import search_options_pb2 +from tensorflow_lite_support.python.task.vision import image_searcher +from tensorflow_lite_support.python.task.vision.core import tensor_image +from tensorflow_lite_support.python.test import test_util + +_BaseOptions = base_options_pb2.BaseOptions +_EmbeddingOptions = embedding_options_pb2.EmbeddingOptions +_SearchOptions = search_options_pb2.SearchOptions +_ImageSearcher = image_searcher.ImageSearcher +_ImageSearcherOptions = image_searcher.ImageSearcherOptions + +_MOBILENET_EMBEDDER_MODEL = 'mobilenet_v3_small_100_224_embedder.tflite' +_MOBILENET_SEARCHER_MODEL = 'mobilenet_v3_small_100_224_searcher.tflite' +_MOBILENET_INDEX = 'searcher_index.ldb' + +_IMAGE_FILE = 'burger.jpg' +_MAX_RESULTS = 2 + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class IndexFileType(enum.Enum): + NONE = 1 + FILE_CONTENT = 2 + FILE_NAME = 3 + + +class ImageSearcherTest(parameterized.TestCase, 
tf.test.TestCase): + + def setUp(self): + super().setUp() + self.test_image_path = test_util.get_test_data_path(_IMAGE_FILE) + self.embedder_model_path = test_util.get_test_data_path( + _MOBILENET_EMBEDDER_MODEL) + self.searcher_model_path = test_util.get_test_data_path( + _MOBILENET_SEARCHER_MODEL) + self.index_path = test_util.get_test_data_path(_MOBILENET_INDEX) + + def test_create_from_file_succeeds_with_valid_embedder_and_index_paths(self): + # Creates with default option and valid model and index files successfully. + searcher = _ImageSearcher.create_from_file(self.embedder_model_path, + self.index_path) + self.assertIsInstance(searcher, _ImageSearcher) + + def test_create_from_file_succeeds_with_valid_searcher_path(self): + # Creates with default option and valid searcher model. + searcher = _ImageSearcher.create_from_file(self.searcher_model_path) + self.assertIsInstance(searcher, _ImageSearcher) + + def test_create_from_options_succeeds_with_valid_embedder_and_index_paths( + self): + options = _ImageSearcherOptions( + base_options=_BaseOptions(file_name=self.embedder_model_path), + search_options=_SearchOptions(index_file_name=self.index_path)) + searcher = _ImageSearcher.create_from_options(options) + self.assertIsInstance(searcher, _ImageSearcher) + + def test_create_from_options_succeeds_with_valid_searcher_path(self): + options = _ImageSearcherOptions( + base_options=_BaseOptions(file_name=self.searcher_model_path), + search_options=_SearchOptions()) + searcher = _ImageSearcher.create_from_options(options) + self.assertIsInstance(searcher, _ImageSearcher) + + def test_create_from_options_succeeds_with_valid_embedder_content(self): + # Creates with options containing model content successfully. 
+ with open(self.embedder_model_path, 'rb') as f: + options = _ImageSearcherOptions( + base_options=_BaseOptions(file_content=f.read()), + search_options=_SearchOptions(index_file_name=self.index_path)) + searcher = _ImageSearcher.create_from_options(options) + self.assertIsInstance(searcher, _ImageSearcher) + + def test_create_from_options_succeeds_with_valid_searcher_content(self): + # Creates with options containing model content successfully. + with open(self.searcher_model_path, 'rb') as f: + options = _ImageSearcherOptions( + base_options=_BaseOptions(file_content=f.read()), + search_options=_SearchOptions()) + searcher = _ImageSearcher.create_from_options(options) + self.assertIsInstance(searcher, _ImageSearcher) + + def test_create_from_options_succeeds_with_valid_index_content(self): + # Creates with options containing index content successfully. + with open(self.index_path, 'rb') as f: + options = _ImageSearcherOptions( + base_options=_BaseOptions(file_name=self.embedder_model_path), + search_options=_SearchOptions(index_file_content=f.read())) + searcher = _ImageSearcher.create_from_options(options) + self.assertIsInstance(searcher, _ImageSearcher) + + def test_create_from_options_fails_with_invalid_index_path(self): + # Invalid index path. + with self.assertRaisesRegex( + ValueError, + r'Unable to find index file: SearchOptions.index_file is not set and ' + r'no AssociatedFile with type SCANN_INDEX_FILE could be found in the ' + r'output tensor metadata.'): + options = _ImageSearcherOptions( + base_options=_BaseOptions(file_name=self.embedder_model_path)) + _ImageSearcher.create_from_options(options) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. 
+ with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name' or 'file_descriptor_meta'."): + options = _ImageSearcherOptions( + base_options=_BaseOptions(file_name=''), + search_options=_SearchOptions(index_file_name=self.index_path)) + _ImageSearcher.create_from_options(options) + + def test_create_from_options_fails_with_invalid_quantization(self): + # Invalid quantization option. + with self.assertRaisesRegex( + ValueError, + r'Setting EmbeddingOptions.quantize = true is not allowed in ' + r'searchers.'): + options = _ImageSearcherOptions( + base_options=_BaseOptions(file_name=self.embedder_model_path), + embedding_options=_EmbeddingOptions(quantize=True), + search_options=_SearchOptions(index_file_name=self.index_path)) + _ImageSearcher.create_from_options(options) + + def test_create_from_options_fails_with_invalid_max_results(self): + # Invalid max results option. + with self.assertRaisesRegex( + ValueError, r'SearchOptions.max_results must be > 0, found -1.'): + options = _ImageSearcherOptions( + base_options=_BaseOptions(file_name=self.embedder_model_path), + search_options=_SearchOptions( + index_file_name=self.index_path, max_results=-1)) + _ImageSearcher.create_from_options(options) + + @parameterized.parameters( + (_MOBILENET_EMBEDDER_MODEL, ModelFileType.FILE_NAME, + IndexFileType.FILE_NAME), + (_MOBILENET_EMBEDDER_MODEL, ModelFileType.FILE_CONTENT, + IndexFileType.FILE_NAME), + (_MOBILENET_EMBEDDER_MODEL, ModelFileType.FILE_NAME, + IndexFileType.FILE_CONTENT), + (_MOBILENET_EMBEDDER_MODEL, ModelFileType.FILE_CONTENT, + IndexFileType.FILE_CONTENT), + (_MOBILENET_SEARCHER_MODEL, ModelFileType.FILE_NAME, IndexFileType.NONE), + (_MOBILENET_SEARCHER_MODEL, ModelFileType.FILE_CONTENT, + IndexFileType.NONE), + ) + def test_search(self, model_name, model_file_type, index_file_type): + # Create BaseOptions. 
+ model_path = test_util.get_test_data_path(model_name) + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(file_name=model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(file_content=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + # Create SearchOptions. + if index_file_type is IndexFileType.NONE: + search_options = _SearchOptions() + else: + index_path = test_util.get_test_data_path(_MOBILENET_INDEX) + if index_file_type is IndexFileType.FILE_NAME: + search_options = _SearchOptions(index_file_name=index_path) + elif index_file_type is IndexFileType.FILE_CONTENT: + with open(index_path, 'rb') as f: + index_content = f.read() + search_options = _SearchOptions(index_file_content=index_content) + else: + # Should never happen + raise ValueError('index_file_type is invalid.') + + # Create searcher. + options = _ImageSearcherOptions( + base_options, _EmbeddingOptions(l2_normalize=True, quantize=False), + search_options) + searcher = _ImageSearcher.create_from_options(options) + + # Loads image. + image = tensor_image.TensorImage.create_from_file(self.test_image_path) + + # Perform image search. + image_search_result = searcher.search(image) + + # Expected results. + expected_result_text_proto = """ + nearest_neighbors { metadata: "burger" distance: -0.0 } + nearest_neighbors { metadata: "car" distance: 1.822435 } + nearest_neighbors { metadata: "bird" distance: 1.930939 } + nearest_neighbors { metadata: "dog" distance: 2.047355 } + nearest_neighbors { metadata: "cat" distance: 2.075868 } + """ + + # Comparing results. + self.assertProtoEquals(expected_result_text_proto, image_search_result) + + # Get user info and compare values. + self.assertEqual(searcher.get_user_info(), 'userinfo') + + def test_search_with_bounding_box(self): + # Create searcher. 
+ searcher = _ImageSearcher.create_from_file(self.embedder_model_path, + self.index_path) + + # Loads image. + image = tensor_image.TensorImage.create_from_file(self.test_image_path) + + # Bounding box in "burger.jpg" corresponding to "burger_crop.jpg". + bounding_box = bounding_box_pb2.BoundingBox( + origin_x=0, origin_y=0, width=400, height=325) + + # Perform image search. + image_search_result = searcher.search(image, bounding_box) + + # Expected results. + expected_result_text_proto = """ + nearest_neighbors { metadata: "burger" distance: 184.85214 } + nearest_neighbors { metadata: "car" distance: 209.32019 } + nearest_neighbors { metadata: "bird" distance: 211.43195 } + nearest_neighbors { metadata: "dog" distance: 212.77237 } + nearest_neighbors { metadata: "cat" distance: 212.8553 } + """ + + # Comparing results. + self.assertProtoEquals(expected_result_text_proto, image_search_result) + + # Get user info and compare values. + self.assertEqual(searcher.get_user_info(), 'userinfo') + + def test_max_results_option(self): + # Create searcher. + base_options = _BaseOptions(file_name=self.embedder_model_path) + search_options = _SearchOptions( + index_file_name=self.index_path, max_results=_MAX_RESULTS) + options = _ImageSearcherOptions(base_options, + _EmbeddingOptions(l2_normalize=True), + search_options) + searcher = _ImageSearcher.create_from_options(options) + + # Loads image. + image = tensor_image.TensorImage.create_from_file(self.test_image_path) + + # Perform image search. + image_search_result = searcher.search(image) + nearest_neighbors = image_search_result.nearest_neighbors + + self.assertLessEqual( + len(nearest_neighbors), _MAX_RESULTS, 'Too many results returned.') + + +if __name__ == '__main__': + tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_segmenter_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_segmenter_test.py index c5608c45..866f8ef 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_segmenter_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_segmenter_test.py
@@ -16,126 +16,55 @@ import enum from absl.testing import parameterized +import numpy as np import tensorflow as tf + from tensorflow_lite_support.python.task.core.proto import base_options_pb2 from tensorflow_lite_support.python.task.processor.proto import segmentation_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import segmentations_pb2 from tensorflow_lite_support.python.task.vision import image_segmenter from tensorflow_lite_support.python.task.vision.core import tensor_image from tensorflow_lite_support.python.test import test_util _BaseOptions = base_options_pb2.BaseOptions +_ColoredLabel = segmentations_pb2.ColoredLabel +_OutputType = segmentation_options_pb2.OutputType _ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _MODEL_FILE = 'deeplabv3.tflite' _IMAGE_FILE = 'segmentation_input_rotation0.jpg' _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' -_EXPECTED_COLORED_LABELS = [{ - 'r': 0, - 'g': 0, - 'b': 0, - 'class_name': 'background' -}, { - 'r': 128, - 'g': 0, - 'b': 0, - 'class_name': 'aeroplane' -}, { - 'r': 0, - 'g': 128, - 'b': 0, - 'class_name': 'bicycle' -}, { - 'r': 128, - 'g': 128, - 'b': 0, - 'class_name': 'bird' -}, { - 'r': 0, - 'g': 0, - 'b': 128, - 'class_name': 'boat' -}, { - 'r': 128, - 'g': 0, - 'b': 128, - 'class_name': 'bottle' -}, { - 'r': 0, - 'g': 128, - 'b': 128, - 'class_name': 'bus' -}, { - 'r': 128, - 'g': 128, - 'b': 128, - 'class_name': 'car' -}, { - 'r': 64, - 'g': 0, - 'b': 0, - 'class_name': 'cat' -}, { - 'r': 192, - 'g': 0, - 'b': 0, - 'class_name': 'chair' -}, { - 'r': 64, - 'g': 128, - 'b': 0, - 'class_name': 'cow' -}, { - 'r': 192, - 'g': 128, - 'b': 0, - 'class_name': 'dining table' -}, { - 'r': 64, - 'g': 0, - 'b': 128, - 'class_name': 'dog' -}, { - 'r': 192, - 'g': 0, - 'b': 128, - 'class_name': 'horse' -}, { - 'r': 64, - 'g': 128, - 'b': 128, - 'class_name': 'motorbike' -}, { - 'r': 192, - 'g': 128, - 'b': 128, - 
'class_name': 'person' -}, { - 'r': 0, - 'g': 64, - 'b': 0, - 'class_name': 'potted plant' -}, { - 'r': 128, - 'g': 64, - 'b': 0, - 'class_name': 'sheep' -}, { - 'r': 0, - 'g': 192, - 'b': 0, - 'class_name': 'sofa' -}, { - 'r': 128, - 'g': 192, - 'b': 0, - 'class_name': 'train' -}, { - 'r': 0, - 'g': 64, - 'b': 128, - 'class_name': 'tv' -}] +_EXPECTED_COLORED_LABELS = [ + _ColoredLabel(color=(0, 0, 0), category_name='background', display_name=''), + _ColoredLabel( + color=(128, 0, 0), category_name='aeroplane', display_name=''), + _ColoredLabel(color=(0, 128, 0), category_name='bicycle', display_name=''), + _ColoredLabel(color=(128, 128, 0), category_name='bird', display_name=''), + _ColoredLabel(color=(0, 0, 128), category_name='boat', display_name=''), + _ColoredLabel(color=(128, 0, 128), category_name='bottle', display_name=''), + _ColoredLabel(color=(0, 128, 128), category_name='bus', display_name=''), + _ColoredLabel(color=(128, 128, 128), category_name='car', display_name=''), + _ColoredLabel(color=(64, 0, 0), category_name='cat', display_name=''), + _ColoredLabel(color=(192, 0, 0), category_name='chair', display_name=''), + _ColoredLabel(color=(64, 128, 0), category_name='cow', display_name=''), + _ColoredLabel( + color=(192, 128, 0), category_name='dining table', display_name=''), + _ColoredLabel(color=(64, 0, 128), category_name='dog', display_name=''), + _ColoredLabel(color=(192, 0, 128), category_name='horse', display_name=''), + _ColoredLabel( + color=(64, 128, 128), category_name='motorbike', display_name=''), + _ColoredLabel( + color=(192, 128, 128), category_name='person', display_name=''), + _ColoredLabel( + color=(0, 64, 0), category_name='potted plant', display_name=''), + _ColoredLabel(color=(128, 64, 0), category_name='sheep', display_name=''), + _ColoredLabel(color=(0, 192, 0), category_name='sofa', display_name=''), + _ColoredLabel(color=(128, 192, 0), category_name='train', display_name=''), + _ColoredLabel(color=(0, 64, 128), 
category_name='tv', display_name='') +] +_MASK_MAGNIFICATION_FACTOR = 10 +_MATCH_PIXELS_THRESHOLD = 0.01 def _create_segmenter_from_options(base_options, **segmentation_options): @@ -160,6 +89,36 @@ self.test_seg_path = test_util.get_test_data_path(_SEGMENTATION_FILE) self.model_path = test_util.get_test_data_path(_MODEL_FILE) + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + segmenter = _ImageSegmenter.create_from_file(self.model_path) + self.assertIsInstance(segmenter, _ImageSegmenter) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(file_name=self.model_path) + options = _ImageSegmenterOptions(base_options=base_options) + segmenter = _ImageSegmenter.create_from_options(options) + self.assertIsInstance(segmenter, _ImageSegmenter) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. + with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name' or 'file_descriptor_meta'."): + base_options = _BaseOptions(file_name='') + options = _ImageSegmenterOptions(base_options=base_options) + _ImageSegmenter.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. 
+ with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(file_content=f.read()) + options = _ImageSegmenterOptions(base_options=base_options) + segmenter = _ImageSegmenter.create_from_options(options) + self.assertIsInstance(segmenter, _ImageSegmenter) + @parameterized.parameters( (ModelFileType.FILE_NAME, _EXPECTED_COLORED_LABELS), (ModelFileType.FILE_CONTENT, _EXPECTED_COLORED_LABELS)) @@ -181,20 +140,97 @@ image = tensor_image.TensorImage.create_from_file(self.test_image_path) # Performs image segmentation on the input. - segmentation = segmenter.segment(image).segmentation[0] + segmentation = segmenter.segment(image).segmentations[0] colored_labels = segmentation.colored_labels - # Check if the sizes of the result and expected colored labels are the same. - self.assertEqual( - len(colored_labels), len(expected_colored_labels), - 'Number of colored labels do not match.') - # Comparing results. - for index in range(len(expected_colored_labels)): - for key in expected_colored_labels[index].keys(): - self.assertEqual( - getattr(colored_labels[index], key), - expected_colored_labels[index][key]) + self.assertEqual(colored_labels, expected_colored_labels, + 'Colored labels do not match.') + + def test_segmentation_category_mask(self): + """Check if category mask matches with ground truth.""" + # Creates segmenter. + base_options = _BaseOptions(file_name=self.model_path) + segmenter = _create_segmenter_from_options( + base_options, output_type=_OutputType.CATEGORY_MASK) + + # Loads image. + image = tensor_image.TensorImage.create_from_file(self.test_image_path) + + # Performs image segmentation on the input. + segmentation = segmenter.segment(image).segmentations[0] + result_pixels = segmentation.category_mask.flatten() + + # Check if data type of `category_mask` is correct. + self.assertEqual(result_pixels.dtype, np.uint8) + + # Loads ground truth segmentation file.
+ gt_segmentation = tensor_image.TensorImage.create_from_file( + self.test_seg_path) + gt_segmentation_array = gt_segmentation.buffer + gt_segmentation_shape = gt_segmentation_array.shape + num_pixels = gt_segmentation_shape[0] * gt_segmentation_shape[1] + ground_truth_pixels = gt_segmentation_array.flatten() + + self.assertEqual( + len(result_pixels), len(ground_truth_pixels), + 'Segmentation mask size does not match the ground truth mask size.') + + inconsistent_pixels = 0 + + for index in range(num_pixels): + inconsistent_pixels += ( + result_pixels[index] * _MASK_MAGNIFICATION_FACTOR != + ground_truth_pixels[index]) + + self.assertLessEqual( + inconsistent_pixels / num_pixels, _MATCH_PIXELS_THRESHOLD, + f'Number of pixels in the candidate mask differing from that of the ' + f'ground truth mask exceeds {_MATCH_PIXELS_THRESHOLD}.') + + def test_segmentation_confidence_mask_matches_category_mask(self): + """Check if the confidence mask matches with the category mask.""" + # Create BaseOptions from model file. + base_options = _BaseOptions(file_name=self.model_path) + + # Loads image. + image = tensor_image.TensorImage.create_from_file(self.test_image_path) + + # Run segmentation on the model in CATEGORY_MASK mode. + segmenter = _create_segmenter_from_options( + base_options, output_type=_OutputType.CATEGORY_MASK) + + # Performs image segmentation on the input and gets the category mask. + segmentation = segmenter.segment(image).segmentations[0] + category_mask = segmentation.category_mask + + # Run segmentation on the model in CONFIDENCE_MASK mode. + segmenter = _create_segmenter_from_options( + base_options, output_type=_OutputType.CONFIDENCE_MASK) + + # Performs image segmentation on the input again. + segmentation = segmenter.segment(image).segmentations[0] + # Gets the list of confidence masks and colored_labels. + confidence_masks = segmentation.confidence_masks + colored_labels = segmentation.colored_labels + + # Check if confidence mask shape is correct.
+ self.assertEqual( + len(confidence_masks), len(colored_labels), + 'Number of confidence masks must match with number of categories.') + + # Gather the confidence masks in a single array `confidence_mask_array`. + confidence_mask_array = np.array( + [confidence_mask.value for confidence_mask in confidence_masks]) + + # Check if data type of `confidence_masks` is correct. + self.assertEqual(confidence_mask_array.dtype, np.float64) + + # Compute the category mask from the created confidence mask. + calculated_category_mask = np.argmax(confidence_mask_array, axis=0) + self.assertListEqual( + calculated_category_mask.tolist(), category_mask.tolist(), + 'Confidence mask does not match with the category mask.') if __name__ == '__main__':
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/object_detector_test.py b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/object_detector_test.py index 80375fd..a483ece0 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/object_detector_test.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/object_detector_test.py
@@ -14,21 +14,14 @@ """Tests for object detector.""" import enum -import json from absl.testing import parameterized -# TODO(b/220067158): Change to import tensorflow and leverage tf.test once -# fixed the dependency issue. -from google.protobuf import json_format -import unittest +import tensorflow as tf + from tensorflow_lite_support.python.task.core.proto import base_options_pb2 -from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2 -from tensorflow_lite_support.python.task.processor.proto import class_pb2 from tensorflow_lite_support.python.task.processor.proto import detection_options_pb2 -from tensorflow_lite_support.python.task.processor.proto import detections_pb2 from tensorflow_lite_support.python.task.vision import object_detector from tensorflow_lite_support.python.task.vision.core import tensor_image -from tensorflow_lite_support.python.test import base_test from tensorflow_lite_support.python.test import test_util _BaseOptions = base_options_pb2.BaseOptions @@ -37,53 +30,28 @@ _MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite' _IMAGE_FILE = 'cats_and_dogs.jpg' -_EXPECTED_DETECTIONS = [ - ({ - 'origin_x': 54, - 'origin_y': 396, - 'width': 393, - 'height': 196 - }, { - 'index': 16, - 'score': 0.64453125, - 'class_name': 'cat' - }), - ({ - 'origin_x': 602, - 'origin_y': 157, - 'width': 394, - 'height': 447 - }, { - 'index': 16, - 'score': 0.59765625, - 'class_name': 'cat' - }), - ({ - 'origin_x': 261, - 'origin_y': 394, - 'width': 179, - 'height': 209 - }, { - 'index': 16, - 'score': 0.5625, - 'class_name': 'cat' - }), - ({ - 'origin_x': 389, - 'origin_y': 197, - 'width': 276, - 'height': 409 - }, { - 'index': 17, - 'score': 0.51171875, - 'class_name': 'dog' - }) -] +_EXPECTED_DETECTIONS = """ +detections { + bounding_box { origin_x: 54 origin_y: 396 width: 393 height: 196 } + classes { index: 16 score: 0.64453125 display_name: "" class_name: "cat" } +} +detections { + bounding_box { origin_x: 602 origin_y: 157 
width: 394 height: 447 } + classes { index: 16 score: 0.59765625 display_name: "" class_name: "cat" } +} +detections { + bounding_box { origin_x: 261 origin_y: 394 width: 179 height: 209 } + classes { index: 16 score: 0.5625 display_name: "" class_name: "cat" } +} +detections { + bounding_box { origin_x: 389 origin_y: 197 width: 276 height: 409 } + classes { index: 17 score: 0.51171875 display_name: "" class_name: "dog" } +} +""" _ALLOW_LIST = ['cat', 'dog'] _DENY_LIST = ['cat'] _SCORE_THRESHOLD = 0.3 _MAX_RESULTS = 3 -_ACCEPTABLE_ERROR_RANGE = 0.000001 class ModelFileType(enum.Enum): @@ -100,23 +68,7 @@ return detector -def _build_test_data(expected_detections): - expected_result = detections_pb2.DetectionResult() - - for index in range(len(expected_detections)): - bounding_box, category = expected_detections[index] - detection = detections_pb2.Detection() - detection.bounding_box.CopyFrom( - bounding_box_pb2.BoundingBox(**bounding_box)) - detection.classes.append(class_pb2.Category(**category)) - expected_result.detections.append(detection) - - expected_result_dict = json.loads(json_format.MessageToJson(expected_result)) - - return expected_result_dict - - -class ObjectDetectorTest(parameterized.TestCase, base_test.BaseTestCase): +class ObjectDetectorTest(parameterized.TestCase, tf.test.TestCase): def setUp(self): super().setUp() @@ -157,7 +109,7 @@ (ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTIONS), (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTIONS)) def test_detect_model(self, model_file_type, max_results, - expected_detections): + expected_result_text_proto): # Creates detector. if model_file_type is ModelFileType.FILE_NAME: base_options = _BaseOptions(file_name=self.model_path) @@ -177,14 +129,9 @@ # Performs object detection on the input. image_result = detector.detect(image) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - - # Builds test data. - expected_result_dict = _build_test_data(expected_detections) # Comparing results. 
- self.assertDeepAlmostEqual( - image_result_dict, expected_result_dict, delta=_ACCEPTABLE_ERROR_RANGE) + self.assertProtoEquals(expected_result_text_proto, image_result.to_pb2()) def test_score_threshold_option(self): # Creates detector. @@ -197,16 +144,13 @@ # Performs object detection on the input. image_result = detector.detect(image) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) + detections = image_result.detections - categories = image_result_dict['detections'] - - for category in categories: - score = category['classes'][0]['score'] + for detection in detections: + score = detection.categories[0].score self.assertGreaterEqual( score, _SCORE_THRESHOLD, - 'Classification with score lower than threshold found. {0}'.format( - category)) + f'Detection with score lower than threshold found. {detection}') def test_max_results_option(self): # Creates detector. @@ -219,8 +163,7 @@ # Performs object detection on the input. image_result = detector.detect(image) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) - detections = image_result_dict['detections'] + detections = image_result.detections self.assertLessEqual( len(detections), _MAX_RESULTS, 'Too many results returned.') @@ -229,42 +172,37 @@ # Creates detector. base_options = _BaseOptions(file_name=self.model_path) detector = _create_detector_from_options( - base_options, class_name_allowlist=_ALLOW_LIST) + base_options, category_name_allowlist=_ALLOW_LIST) # Loads image. image = tensor_image.TensorImage.create_from_file(self.test_image_path) # Performs object detection on the input. 
image_result = detector.detect(image) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) + detections = image_result.detections - categories = image_result_dict['detections'] - - for category in categories: - label = category['classes'][0]['className'] - self.assertIn( - label, _ALLOW_LIST, - 'Label "{0}" found but not in label allow list'.format(label)) + for detection in detections: + label = detection.categories[0].category_name + self.assertIn(label, _ALLOW_LIST, + f'Label {label} found but not in label allow list') def test_deny_list_option(self): # Creates detector. base_options = _BaseOptions(file_name=self.model_path) detector = _create_detector_from_options( - base_options, class_name_denylist=_DENY_LIST) + base_options, category_name_denylist=_DENY_LIST) # Loads image. image = tensor_image.TensorImage.create_from_file(self.test_image_path) # Performs object detection on the input. image_result = detector.detect(image) - image_result_dict = json.loads(json_format.MessageToJson(image_result)) + detections = image_result.detections - categories = image_result_dict['detections'] - - for category in categories: - label = category['classes'][0]['className'] + for detection in detections: + label = detection.categories[0].category_name self.assertNotIn(label, _DENY_LIST, - 'Label "{0}" found but in deny list.'.format(label)) + f'Label {label} found but in deny list.') def test_combined_allowlist_and_denylist(self): # Fails with combined allowlist and denylist @@ -274,11 +212,11 @@ r'exclusive options.'): base_options = _BaseOptions(file_name=self.model_path) detection_options = detection_options_pb2.DetectionOptions( - class_name_allowlist=['foo'], class_name_denylist=['bar']) + category_name_allowlist=['foo'], category_name_denylist=['bar']) options = _ObjectDetectorOptions( base_options=base_options, detection_options=detection_options) _ObjectDetector.create_from_options(options) if __name__ == '__main__': - unittest.main() + 
tf.test.main()
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/README.md b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/README.md new file mode 100644 index 0000000..1b5da3c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/README.md
@@ -0,0 +1,18 @@ +# On-device Implementation of ScaNN + +[ScaNN](https://github.com/google-research/google-research/tree/master/scann) +(Scalable Nearest Neighbors) is a method for efficient vector similarity search +at scale. This is a simplified version of +[ScaNN](https://github.com/google-research/google-research/tree/master/scann) +that requires fewer resources to run and supports inference only. There is no +support for K-Means partitioning training or quantization training. It supports +retrieval with the following features: + +1. K-Means tree space partitioning. +2. [Asymmetric Hashing](https://research.google/pubs/pub41694/) (AH) + quantization. +3. `dot_product` and `squared_l2` distance measures. Note that for + `dot_product` distance, we return the *negative* dot product. This is to + ensure consistency with `squared_l2`, where smaller means closer. +4. Indexing new embeddings, including assigning them to the closest partitions + and AH-quantizing them.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h new file mode 100644 index 0000000..67e0e30 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h
@@ -0,0 +1,257 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_INDEX_TABLE_SUM_H_ +#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_INDEX_TABLE_SUM_H_ + +#include <algorithm> +#include <array> +#include <cstddef> +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "Eigen/Core" // from @eigen +#include "tensorflow_lite_support/scann_ondevice/cc/core/simd_utils.h" + +namespace tflite { +namespace scann_ondevice { +namespace core { + +template <typename LutType> +void RearrangeLUT(const LutType* input_data, + int batch_elems, + int batch_size, + LutType* const output_data) { + std::vector<int64_t> simd_sizes; + if (std::is_same<LutType, float>::value) { +#ifdef __AVX__ + simd_sizes = {8, 4}; +#elif defined __SSE__ + simd_sizes = {4}; +#elif defined __ARM_NEON__ + simd_sizes = {4}; +#endif + } else { +#ifdef __AVX2__ + simd_sizes = {16, 8}; +#elif defined __SSE4_1__ + simd_sizes = {8}; +#elif defined __ARM_NEON__ + simd_sizes = {8}; +#endif + } + + int64_t offset = 0; + for (int64_t simd_size : simd_sizes) { + const int64_t num_simds = batch_size / simd_size; + const int64_t simd_batch_elems = simd_size * batch_elems; + for (; offset < num_simds * simd_batch_elems; offset += simd_batch_elems) { + using RowMajorMatrix = Eigen::Matrix<LutType, Eigen::Dynamic, + Eigen::Dynamic,
Eigen::RowMajor>; + Eigen::Map<const RowMajorMatrix> input_map(input_data + offset, simd_size, + batch_elems); + Eigen::Map<RowMajorMatrix> output_map(output_data + offset, batch_elems, + simd_size); + output_map = input_map.transpose(); + } + } + std::copy(input_data + offset, input_data + batch_elems * batch_size, + output_data + offset); +} +const int kDefaultChunksPerBlock = 32; +const int k16CentersUint8LutChunksPerBlock = 256; +const int kUnrollSteps = 6; + +template <typename T> +struct MaxQuantizationValue { + static_assert(std::is_same<T, float>::value, "Invalid lookup table type."); + static constexpr size_t value = 0; +}; + +template <> +struct MaxQuantizationValue<uint8_t> { + static constexpr size_t value = 255; +}; + +template <> +struct MaxQuantizationValue<uint16_t> { + static constexpr size_t value = (1 << 16) / kDefaultChunksPerBlock - 1; +}; + +template <typename SimdType, typename LutType, size_t NumCenters = 0> +size_t IndexTableSumSimdBatch(const uint8_t* indices, + size_t num_chunks, + size_t num_outputs, + const LutType* lookup_table, + size_t batch_size, + size_t num_centers, + float min, + float max, + size_t batch_index, + float* const output) { + if (num_centers == 256) { + return IndexTableSumSimdBatch<SimdType, LutType, 256>( + indices, num_chunks, num_outputs, lookup_table, batch_size, 0, min, max, + batch_index, output); + } + const size_t lut_chunk_stride = NumCenters ? NumCenters * SimdType::size() + : num_centers * SimdType::size(); + const size_t lut_item_stride = + NumCenters ? NumCenters * num_chunks : num_chunks * num_centers; + constexpr bool must_dequantize = !std::is_same<LutType, float>::value; + constexpr size_t max_qval = MaxQuantizationValue<LutType>::value; + const float dq_scale = must_dequantize ? (max - min) / max_qval : 0.0f; + const float dq_offset_1 = must_dequantize ? min + dq_scale / 2 : 0.0f; + + const size_t chunks_per_block = + std::is_same<LutType, uint8_t>::value && + (NumCenters ? 
NumCenters : num_centers) == 16 + ? k16CentersUint8LutChunksPerBlock + : kDefaultChunksPerBlock; + + for (; batch_index + SimdType::size() <= batch_size; + batch_index += SimdType::size()) { + const LutType* batch_lut = lookup_table + batch_index * lut_item_stride; + float* const batch_output = output + batch_index; + for (size_t block_start = 0; block_start < num_chunks; + block_start += chunks_per_block) { + const size_t block_end = + std::min(block_start + chunks_per_block, num_chunks); + const float dq_offset_n = (block_end - block_start) * dq_offset_1; + size_t output_index; + for (output_index = 0; output_index + kUnrollSteps <= num_outputs; + output_index += kUnrollSteps) { + const uint8_t* indices_base = indices + output_index * num_chunks; + size_t chunk_index = block_start; + const LutType* chunk_lut = batch_lut + chunk_index * lut_chunk_stride; + std::array<SimdType, kUnrollSteps> accums; + for (size_t i = 0; i < kUnrollSteps; ++i) { + const size_t center_index = + indices_base[i * num_chunks + chunk_index]; + accums[i].load(chunk_lut + center_index * SimdType::size()); + } + ++chunk_index; + chunk_lut += lut_chunk_stride; + for (; chunk_index < block_end; ++chunk_index) { + for (size_t i = 0; i < kUnrollSteps; ++i) { + SimdType simd; + const size_t center_index = + indices_base[i * num_chunks + chunk_index]; + simd.load(chunk_lut + center_index * SimdType::size()); + accums[i] += simd; + } + chunk_lut += lut_chunk_stride; + } + for (size_t i = 0; i < kUnrollSteps; ++i) { + accums[i].dequantize_accum_storeu( + batch_output + (output_index + i) * batch_size, dq_scale, + dq_offset_n); + } + } + for (; output_index < num_outputs; ++output_index) { + const uint8_t* vector_indices = indices + output_index * num_chunks; + + SimdType accum; + accum.setzero(); + size_t chunk_index = block_start; + const LutType* chunk_lut = batch_lut + chunk_index * lut_chunk_stride; + for (; chunk_index < block_end; ++chunk_index) { + SimdType simd; + simd.load(chunk_lut + 
vector_indices[chunk_index] * SimdType::size()); + accum += simd; + chunk_lut += lut_chunk_stride; + } + + accum.dequantize_accum_storeu(batch_output + output_index * batch_size, + dq_scale, dq_offset_n); + } + } + } + + return batch_index; +} + +template <typename LutType> +void IndexTableSum(const uint8_t* indices, + size_t num_chunks, + size_t num_outputs, + const LutType* lookup_table, + size_t batch_size, + size_t num_centers, + float min, + float max, + float* const output) { + static_assert(std::is_same<LutType, uint8_t>::value || + std::is_same<LutType, uint16_t>::value, + "Invalid lookup table type."); + std::fill(output, output + batch_size * num_outputs, 0.0f); + size_t i = 0; +#ifdef __AVX2__ + i = IndexTableSumSimdBatch<SimdInt16x16, LutType>( + indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, + min, max, i, output); +#endif +#ifdef __SSE4_1__ + i = IndexTableSumSimdBatch<SimdInt16x8, LutType>( + indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, + min, max, i, output); +#endif +#ifdef __ARM_NEON__ + i = IndexTableSumSimdBatch<SimdInt16x8, LutType>( + indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, + min, max, i, output); +#endif + i = IndexTableSumSimdBatch<SimdInt16x1, LutType>( + indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, + min, max, i, output); +} + +template <> +inline void IndexTableSum<float>(const uint8_t* indices, + size_t num_chunks, + size_t num_outputs, + const float* lookup_table, + size_t batch_size, + size_t num_centers, + float min, + float max, + float* const output) { + std::fill(output, output + batch_size * num_outputs, 0.0f); + size_t i = 0; +#ifdef __AVX__ + i = IndexTableSumSimdBatch<SimdFloat32x8, float>( + indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, + min, max, i, output); +#endif +#ifdef __SSE__ + i = IndexTableSumSimdBatch<SimdFloat32x4, float>( + indices, num_chunks, num_outputs, lookup_table, 
batch_size, num_centers, + min, max, i, output); +#endif +#ifdef __ARM_NEON__ + i = IndexTableSumSimdBatch<SimdFloat32x4, float>( + indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, + min, max, i, output); +#endif + i = IndexTableSumSimdBatch<SimdFloat32x1, float>( + indices, num_chunks, num_outputs, lookup_table, batch_size, num_centers, + min, max, i, output); +} + +} // namespace core +} // namespace scann_ondevice +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_INDEX_TABLE_SUM_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h new file mode 100644 index 0000000..f4e9eb9 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h
@@ -0,0 +1,78 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PARTITIONER_H_ +#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PARTITIONER_H_ + +#include <memory> +#include <utility> +#include <vector> + +#include "Eigen/Core" // from @eigen +#include "absl/types/optional.h" // from @com_google_absl +#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h" + +namespace tflite { +namespace scann_ondevice { +namespace core { +class PartitionerInterface { + public: + virtual ~PartitionerInterface() {} + virtual bool Partition(const Eigen::Ref<const Eigen::MatrixXf>& queries, + std::vector<std::vector<int>>* tokens) const = 0; + + virtual int NumPartitions() const = 0; + virtual absl::optional<int> get_vector_dimension() const = 0; +}; +class Partitioner : public PartitionerInterface { + public: + static std::unique_ptr<Partitioner> Create(const PartitionerProto& proto); + bool Partition(const Eigen::Ref<const Eigen::MatrixXf>& queries, + std::vector<std::vector<int>>* tokens) const override; + int NumPartitions() const override; + + inline absl::optional<int> get_vector_dimension() const override { + return leaves_.cols(); + } + + private: + Partitioner(Eigen::MatrixXf leaves, + Eigen::VectorXf leaf_norms, + DistanceMeasure distance) + : leaves_(std::move(leaves)), +
leaf_norms_(std::move(leaf_norms)), + distance_(distance) {} + + Eigen::MatrixXf leaves_; + Eigen::VectorXf leaf_norms_; + DistanceMeasure distance_; +}; +class NoOpPartitioner : public PartitionerInterface { + public: + ~NoOpPartitioner() override {} + + bool Partition(const Eigen::Ref<const Eigen::MatrixXf>& queries, + std::vector<std::vector<int>>* tokens) const override; + + int NumPartitions() const override; + inline absl::optional<int> get_vector_dimension() const override { + return absl::optional<int>(); + } +}; + +} // namespace core +} // namespace scann_ondevice +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PARTITIONER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/processor.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/processor.h new file mode 100644 index 0000000..97206f4 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/processor.h
@@ -0,0 +1,125 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PROCESSOR_H_
+#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PROCESSOR_H_
+
+#include <cmath>
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "Eigen/Core"  // from @eigen
+#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h"
+
+namespace tflite {
+namespace scann_ondevice {
+namespace core {
+
+// Per-batch query state filled in by a PreProcessorInterface: lookup tables
+// (LUTs) in float, uint16 or uint8 form, plus the fixed-point parameters
+// associated with the integer forms.
+struct QueryInfo {
+  template <typename T>
+  using Matrix =
+      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;
+
+  // Fixed-point quantization parameters; NAN until set by a preprocessor.
+  float fixed_point_min = NAN;
+  float fixed_point_max = NAN;
+  float fixed_point_offset = NAN;
+  float fixed_point_scale = NAN;
+
+  // Lookup tables. NOTE(review): presumably only the variants needed by the
+  // configured lookup type are populated; confirm in processor.cc.
+  std::shared_ptr<Matrix<float>> query_lut;
+  std::shared_ptr<Matrix<uint16_t>> query_lut_uint16;
+  std::shared_ptr<Matrix<uint8_t>> query_lut_uint8;
+  template <typename T>
+  std::shared_ptr<Matrix<T>> QueryLUT();
+
+  // The same lookup tables in a rearranged layout (see the processor tests).
+  std::shared_ptr<Matrix<float>> transposed_query_lut;
+  std::shared_ptr<Matrix<uint16_t>> transposed_query_lut_uint16;
+  std::shared_ptr<Matrix<uint8_t>> transposed_query_lut_uint8;
+  template <typename T>
+  std::shared_ptr<Matrix<T>> TransposedQueryLUT();
+};
+
+// Converts a batch of float queries into a QueryInfo.
+class PreProcessorInterface {
+ public:
+  virtual ~PreProcessorInterface() {}
+
+  // Fills `query_info` for `queries`; returns whether this succeeded.
+  virtual bool Process(const Eigen::Ref<const Eigen::MatrixXf>& queries,
+                       QueryInfo* query_info) const = 0;
+  virtual int num_database_dims() const = 0;
+  virtual int num_query_dims() const = 0;
+};
+
+// Hook that post-processes the per-query top-N results held in `top_n`.
+class PostProcessorInterface {
+ public:
+  virtual ~PostProcessorInterface() {}
+
+  virtual bool Process(std::vector<TopN>* top_n) const = 0;
+};
+
+// Asymmetric-hashing preprocessor: turns queries into per-codebook-entry
+// distance lookup tables using the codebooks of an AsymmetricHashingProto.
+class AsymmetricHashQuerier : public PreProcessorInterface {
+ public:
+  // Builds a querier from `proto`. NOTE(review): presumably returns nullptr
+  // for an invalid proto; confirm in processor.cc.
+  static std::unique_ptr<AsymmetricHashQuerier> Create(
+      const AsymmetricHashingProto& proto);
+  bool Process(const Eigen::Ref<const Eigen::MatrixXf>& queries,
+               QueryInfo* query_info) const override;
+
+  // Number of codebooks (presumably one per subspace).
+  inline int num_database_dims() const override {
+    return static_cast<int>(codebooks_.size());
+  }
+
+  inline int num_query_dims() const override { return dims_; }
+
+ private:
+  // Use Create(); the constructor performs no validation.
+  AsymmetricHashQuerier(std::vector<Eigen::MatrixXf> codebooks,
+                        std::vector<Eigen::VectorXf> codebook_norms,
+                        DistanceMeasure query_distance,
+                        AsymmetricHashingProto::LookupType lookup_type,
+                        int dims)
+      : dims_(dims),
+        lookup_type_(lookup_type),
+        query_distance_(query_distance),
+        codebooks_(std::move(codebooks)),
+        codebook_norms_(std::move(codebook_norms)) {}
+  void RearrangeLUT(QueryInfo* query_info) const;
+
+  int dims_;
+  AsymmetricHashingProto::LookupType lookup_type_;
+  DistanceMeasure query_distance_;
+  std::vector<Eigen::MatrixXf> codebooks_;
+  std::vector<Eigen::VectorXf> codebook_norms_;
+};
+
+}  // namespace core
+}  // namespace scann_ondevice
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_PROCESSOR_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h new file mode 100644 index 0000000..419681b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher.h
@@ -0,0 +1,313 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_SEARCHER_H_
+#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_SEARCHER_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <glog/logging.h>
+#include "Eigen/Core"  // from @eigen
+#include "absl/types/span.h"  // from @com_google_absl
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/core/index_table_sum.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/core/simd_utils.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/core/top_n_amortized_constant.h"
+
+namespace tflite {
+namespace scann_ondevice {
+namespace core {
+
+// Column-major uint8 matrix. Asymmetric-hashing databases use it with one
+// encoded datapoint per column.
+using Matrix8u =
+    Eigen::Matrix<uint8_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;
+
+namespace internal {
+// Computes the asymmetric-hashing distance from every query described by
+// `query_info` to every datapoint in `database`, writing the distance from
+// query j to datapoint i into output(j, i). Defined out of line.
+void ComputeAHDistance(const QueryInfo& query_info,
+                       Eigen::Ref<const Matrix8u> database,
+                       Eigen::Ref<Eigen::MatrixXf> output);
+
+}  // namespace internal
+
+// Scores every datapoint in `database` against the already-preprocessed
+// queries in `query_info` and pushes (distance, index + global_offset) into
+// the matching entry of `topn`.
+//
+// Returns false, leaving `topn` untouched, if `query_info` has no float
+// lookup table or if `topn` does not hold exactly one entry per query.
+template <class T>
+bool AsymmetricHashFindNeighbors(const QueryInfo& query_info,
+                                 Eigen::Ref<const Matrix8u> database,
+                                 size_t global_offset,
+                                 absl::Span<T> topn) {
+  // The float LUT is required: its column count is the batch size.
+  if (query_info.query_lut == nullptr) {
+    return false;
+  }
+  const Eigen::Index batch_size = query_info.query_lut->cols();
+  if (topn.size() != static_cast<size_t>(batch_size)) {
+    return false;
+  }
+  const Eigen::Index database_size = database.cols();
+  Eigen::MatrixXf output(batch_size, database_size);
+  internal::ComputeAHDistance(query_info, database, output);
+
+  for (Eigen::Index i = 0; i < database_size; ++i) {
+    for (size_t j = 0; j < topn.size(); ++j) {
+      topn[j].emplace(output(j, i), i + global_offset);
+    }
+  }
+  return true;
+}
+
+// Preprocesses `queries` (one query per column) with `preprocessor`, then
+// scores them against `database` as above. Returns false if `topn` does not
+// hold exactly one entry per query or if preprocessing or scoring fails.
+template <class T>
+bool AsymmetricHashFindNeighbors(Eigen::Ref<const Eigen::MatrixXf> queries,
+                                 const PreProcessorInterface& preprocessor,
+                                 Eigen::Ref<const Matrix8u> database,
+                                 size_t global_offset,
+                                 absl::Span<T> topn) {
+  if (static_cast<size_t>(queries.cols()) != topn.size()) {
+    return false;
+  }
+  QueryInfo query_info;
+  return preprocessor.Process(queries, &query_info) &&
+         AsymmetricHashFindNeighbors(query_info, database, global_offset, topn);
+}
+
+// Brute-force search: scores every column of `database` against every
+// column of `queries` under `distance_measure` and pushes
+// (distance, index + global_offset) into the matching entry of `topn`.
+// Dot products are negated so that a smaller value is always better.
+//
+// Returns false, leaving `topn` untouched, if the distance measure is
+// unsupported, if `queries` and `database` differ in dimensionality, or if
+// `topn` has fewer entries than there are queries.
+template <class T>
+bool FloatFindNeighbors(Eigen::Ref<const Eigen::MatrixXf> queries,
+                        Eigen::Ref<const Eigen::MatrixXf> database,
+                        const size_t global_offset,
+                        const DistanceMeasure distance_measure,
+                        absl::Span<T> topn) {
+  const int query_size = queries.cols();
+  const int database_size = database.cols();
+  // Reject inputs that would make the matrix products below ill-formed or
+  // index past the end of `topn`.
+  if (queries.rows() != database.rows() ||
+      topn.size() < static_cast<size_t>(query_size)) {
+    return false;
+  }
+  Eigen::MatrixXf pairwise_distances(query_size, database_size);
+
+  if (distance_measure == SQUARED_L2_DISTANCE) {
+    // ||q - d||^2 = ||q||^2 + ||d||^2 - 2 q.d, evaluated for all pairs.
+    pairwise_distances.colwise() = queries.colwise().squaredNorm().transpose();
+    pairwise_distances.rowwise() += database.colwise().squaredNorm();
+    pairwise_distances -= 2 * queries.transpose() * database;
+  } else if (distance_measure == DOT_PRODUCT) {
+    pairwise_distances = -1 * queries.transpose() * database;
+  } else {
+    LOG(ERROR) << "Unsupported distance measure: "
+               << DistanceMeasure_Name(distance_measure);
+    return false;
+  }
+
+  for (int i = 0; i < database_size; ++i) {
+    for (int j = 0; j < query_size; ++j) {
+      topn[j].emplace(pairwise_distances(j, i), i + global_offset);
+    }
+  }
+  return true;
+}
+
+// Common interface of the leaf searchers below, parameterized on the
+// per-query top-N container type.
+template <class T>
+class SearcherInterfaceT {
+ public:
+  virtual ~SearcherInterfaceT() {}
+
+  // Searches for neighbors of each query (one per column of `queries`),
+  // accumulating results into the matching entry of `topn`.
+  virtual bool FindNeighbors(const Eigen::Ref<const Eigen::MatrixXf>& queries,
+                             std::vector<T>* topn) const = 0;
+};
+
+// Leaf searcher over an asymmetric-hashing (uint8-coded) database.
+template <class T>
+class AsymmetricHashLeafSearcherT : public SearcherInterfaceT<T> {
+ public:
+  // Same as the overload below with an unbounded mini-batch size.
+  static std::unique_ptr<AsymmetricHashLeafSearcherT<T>> Create(
+      std::shared_ptr<QueryInfo::Matrix<uint8_t>> database,
+      int global_offset,
+      std::shared_ptr<PreProcessorInterface> preprocessor);
+  // Returns nullptr if `database` or `preprocessor` is null, if
+  // `global_offset` is negative, or if `mini_batch_size` is 0. Queries are
+  // preprocessed and scored at most `mini_batch_size` at a time.
+  static std::unique_ptr<AsymmetricHashLeafSearcherT<T>> Create(
+      std::shared_ptr<QueryInfo::Matrix<uint8_t>> database,
+      int global_offset,
+      std::shared_ptr<PreProcessorInterface> preprocessor,
+      size_t mini_batch_size);
+  bool FindNeighbors(const Eigen::Ref<const Eigen::MatrixXf>& queries,
+                     std::vector<T>* topn) const override;
+  // Scores queries that were already preprocessed into `query_info`.
+  bool FindNeighbors(const QueryInfo& query_info, std::vector<T>* topn) const;
+
+ private:
+  AsymmetricHashLeafSearcherT(
+      std::shared_ptr<QueryInfo::Matrix<uint8_t>> database,
+      int global_offset,
+      std::shared_ptr<PreProcessorInterface> preprocessor,
+      size_t mini_batch_size)
+      : database_(std::move(database)),
+        global_offset_(global_offset),
+        preprocessor_(std::move(preprocessor)),
+        mini_batch_size_(mini_batch_size) {}
+  std::shared_ptr<QueryInfo::Matrix<uint8_t>> database_ = nullptr;
+  // Added to each datapoint's column index to form the reported index.
+  int global_offset_;
+  std::shared_ptr<PreProcessorInterface> preprocessor_ = nullptr;
+  const size_t mini_batch_size_;
+};
+
+// Leaf searcher over an uncompressed float database (one datapoint per
+// column), scored by brute force.
+template <class T>
+class LinearLeafSearcherT : public SearcherInterfaceT<T> {
+ public:
+  ~LinearLeafSearcherT() override {}
+  // Returns nullptr if `database` is null or `global_offset` is negative.
+  static std::unique_ptr<LinearLeafSearcherT<T>> Create(
+      std::shared_ptr<Eigen::MatrixXf> database,
+      DistanceMeasure distance_measure = SQUARED_L2_DISTANCE,
+      int global_offset = 0);
+
+  bool FindNeighbors(const Eigen::Ref<const Eigen::MatrixXf>& queries,
+                     std::vector<T>* topn) const override;
+
+ private:
+  LinearLeafSearcherT(std::shared_ptr<Eigen::MatrixXf> database,
+                      DistanceMeasure distance_measure,
+                      int global_offset)
+      : database_(std::move(database)),
+        distance_measure_(distance_measure),
+        global_offset_(global_offset) {}
+
+  std::shared_ptr<Eigen::MatrixXf> database_ = nullptr;
+  const DistanceMeasure distance_measure_;
+  int global_offset_;
+};
+
+template <class T>
+std::unique_ptr<AsymmetricHashLeafSearcherT<T>>
+AsymmetricHashLeafSearcherT<T>::Create(
+    std::shared_ptr<Matrix8u> database,
+    int global_offset,
+    std::shared_ptr<PreProcessorInterface> preprocessor) {
+  return AsymmetricHashLeafSearcherT<T>::Create(
+      std::move(database), global_offset, std::move(preprocessor),
+      std::numeric_limits<size_t>::max());
+}
+
+template <class T>
+std::unique_ptr<AsymmetricHashLeafSearcherT<T>>
+AsymmetricHashLeafSearcherT<T>::Create(
+    std::shared_ptr<Matrix8u> database,
+    int global_offset,
+    std::shared_ptr<PreProcessorInterface> preprocessor,
+    size_t mini_batch_size) {
+  if (database == nullptr || preprocessor == nullptr ||
+      mini_batch_size == 0 || global_offset < 0) {
+    return nullptr;
+  }
+  return std::unique_ptr<AsymmetricHashLeafSearcherT<T>>(
+      new AsymmetricHashLeafSearcherT<T>(std::move(database), global_offset,
+                                         std::move(preprocessor),
+                                         mini_batch_size));
+}
+
+template <class T>
+bool AsymmetricHashLeafSearcherT<T>::FindNeighbors(
+    const Eigen::Ref<const Eigen::MatrixXf>& queries,
+    std::vector<T>* topn) const {
+  const size_t num_queries = static_cast<size_t>(queries.cols());
+  if (num_queries != topn->size()) {
+    return false;
+  }
+
+  absl::Span<T> topn_span = absl::MakeSpan(*topn);
+  // Preprocess and score the queries in chunks of at most mini_batch_size_.
+  for (size_t i = 0; i < num_queries; i += mini_batch_size_) {
+    const size_t num_queries_in_batch =
+        std::min(mini_batch_size_, num_queries - i);
+    if (!AsymmetricHashFindNeighbors<T>(
+            queries.middleCols(i, num_queries_in_batch), *preprocessor_,
+            *database_, global_offset_,
+            topn_span.subspan(i, num_queries_in_batch))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+template <class T>
+bool AsymmetricHashLeafSearcherT<T>::FindNeighbors(const QueryInfo& query_info,
+                                                   std::vector<T>* topn) const {
+  return AsymmetricHashFindNeighbors<T>(query_info, *database_, global_offset_,
+                                        absl::MakeSpan(*topn));
+}
+
+template <class T>
+std::unique_ptr<LinearLeafSearcherT<T>> LinearLeafSearcherT<T>::Create(
+    std::shared_ptr<Eigen::MatrixXf> database,
+    DistanceMeasure distance_measure,
+    int global_offset) {
+  if (database == nullptr || global_offset < 0) {
+    return nullptr;
+  }
+  return std::unique_ptr<LinearLeafSearcherT<T>>(new LinearLeafSearcherT<T>(
+      std::move(database), distance_measure, global_offset));
+}
+
+template <class T>
+bool LinearLeafSearcherT<T>::FindNeighbors(
+    const Eigen::Ref<const Eigen::MatrixXf>& queries,
+    std::vector<T>* topn) const {
+  return FloatFindNeighbors<T>(queries, *database_, global_offset_,
+                               distance_measure_, absl::MakeSpan(*topn));
+}
+
+using SearcherInterface = SearcherInterfaceT<TopN>;
+using AsymmetricHashLeafSearcher = AsymmetricHashLeafSearcherT<TopN>;
+using LinearLeafSearcher = LinearLeafSearcherT<TopN>;
+}  // namespace core
+}  // namespace scann_ondevice
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_CORE_SEARCHER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc new file mode 100644 index 0000000..f3931f3 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/searcher_test.cc
@@ -0,0 +1,564 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow_lite_support/scann_ondevice/cc/core/searcher.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <utility>
+
+#include <glog/logging.h>
+#include "Eigen/Core"  // from @eigen
+#include "tensorflow_lite_support/cc/port/gmock.h"
+#include "tensorflow_lite_support/cc/port/gtest.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/proto2.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/core/partitioner.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/core/processor.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
+using TextFormat = ::tflite::support::proto::TextFormat;
+
+using Eigen::MatrixXf;
+using ::testing::ElementsAre;
+using ::testing::Pair;
+using ::testing::TestWithParam;
+using ::testing::Values;
+using Matrix8u =
+    Eigen::Matrix<uint8_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;
+using tflite::scann_ondevice::core::TopN;
+
+// Two subspaces (2-D and 3-D), each with a three-entry codebook.
+const char kExampleAsymmetricHashingProtoString[] =
+    R"(
+  subspace: {
+    entry {
+      dimension: 0.1;
+      dimension: 0.2;
+    }
+    entry: {
+      dimension: 0.2;
+      dimension: 0.1;
+    }
+    entry: {
+      dimension: 0.9;
+      dimension: 0.8;
+    }
+  }
+  subspace: {
+    entry {
+      dimension: -0.1;
+      dimension: -0.2;
+      dimension: -0.3;
+    }
+    entry: {
+      dimension: -0.3;
+      dimension: -0.2;
+      dimension: -0.1;
+    }
+    entry: {
+      dimension: -0.9;
+      dimension: -0.8;
+      dimension: -0.7;
+    }
+  })";
+
+// Four 2-D partitioner leaves.
+const char kExamplePartitionerProtoString[] =
+    R"(
+  leaf: {
+    dimension: 0.1;
+    dimension: 0.2;
+  }
+  leaf: {
+    dimension: 0.2;
+    dimension: 0.1;
+  }
+  leaf: {
+    dimension: 0.9;
+    dimension: 0.7;
+  }
+  leaf: {
+    dimension: 0.3;
+    dimension: 0.3;
+  })";
+namespace tflite {
+namespace scann_ondevice {
+namespace core {
+namespace {
+// `tokens` has two slots per query (one query per column), so each query
+// should receive its two nearest leaves under squared L2 distance.
+TEST(PartitionerTest, Partition) {
+  PartitionerProto proto;
+  TextFormat::ParseFromString(kExamplePartitionerProtoString, &proto);
+  proto.set_query_distance(SQUARED_L2_DISTANCE);
+  auto partitioner = Partitioner::Create(proto);
+  ASSERT_NE(partitioner, nullptr);
+  MatrixXf query(2, 3);
+  query << 0.3, 0.9, -1, 0.2, 0.9, -1;
+
+  std::vector<std::vector<int>> tokens(3, std::vector<int>(2, -1));
+  ASSERT_TRUE(partitioner->Partition(query, &tokens));
+  for (int i = 0; i < 3; ++i) {
+    std::sort(tokens[i].begin(), tokens[i].end());
+  }
+  EXPECT_EQ((std::vector<int>{1, 3}), tokens[0]);
+  EXPECT_EQ((std::vector<int>{2, 3}), tokens[1]);
+  EXPECT_EQ((std::vector<int>{0, 1}), tokens[2]);
+}
+
+// Same queries; leaves are now ranked by dot product (largest first).
+TEST(PartitionerTest, PartitionDotProductDistance) {
+  PartitionerProto proto;
+  TextFormat::ParseFromString(kExamplePartitionerProtoString, &proto);
+  proto.set_query_distance(DOT_PRODUCT);
+  auto partitioner = Partitioner::Create(proto);
+  ASSERT_NE(partitioner, nullptr);
+  MatrixXf query(2, 3);
+  query << 0.3, 0.9, -1, 0.2, 0.9, -1;
+
+  std::vector<std::vector<int>> tokens(3, std::vector<int>(2, -1));
+  ASSERT_TRUE(partitioner->Partition(query, &tokens));
+  for (int i = 0; i < 3; ++i) {
+    std::sort(tokens[i].begin(), tokens[i].end());
+  }
+  EXPECT_EQ((std::vector<int>{2, 3}), tokens[0]);
+  EXPECT_EQ((std::vector<int>{2, 3}), tokens[1]);
+  EXPECT_EQ((std::vector<int>{0, 1}), tokens[2]);
+}
+
+// LUT rows are (subspace, codebook entry) pairs and columns are queries;
+// each value is the squared L2 distance between query sub-vector and entry.
+TEST(ProcessorTest, AsymmetricHashQuerierNonSimd) {
+  AsymmetricHashingProto proto;
+  TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto);
+  auto querier = AsymmetricHashQuerier::Create(proto);
+  ASSERT_NE(querier, nullptr);
+
+  MatrixXf query(5, 2);
+  query << 0, 1, 0, 1, 0, 1, 0, 1, 0, 1;
+  QueryInfo query_info;
+  ASSERT_TRUE(querier->Process(query, &query_info));
+  MatrixXf expected_lut(6, 2);
+  expected_lut << 0.05, 1.45, 0.05, 1.45, 1.45, 0.05, 0.14, 4.34, 0.14, 4.34,
+      1.94, 9.74;
+  ASSERT_TRUE(query_info.query_lut->isApprox(expected_lut, 1e-5));
+}
+
+// With DOT_PRODUCT the LUT holds negated dot products.
+TEST(ProcessorTest, AsymmetricHashQuerierNonSimdDotProduct) {
+  AsymmetricHashingProto proto;
+  TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto);
+  proto.set_query_distance(DOT_PRODUCT);
+  auto querier = AsymmetricHashQuerier::Create(proto);
+  ASSERT_NE(querier, nullptr);
+
+  MatrixXf query(5, 2);
+  query << 0, 1, 0, 1, 0, 1, 0, 1, 0, 1;
+  QueryInfo query_info;
+  ASSERT_TRUE(querier->Process(query, &query_info));
+
+  const auto& query_lut = query_info.query_lut;
+  const float* lut_raw = query_lut->data();
+  EXPECT_THAT(std::vector<float>(lut_raw, lut_raw + query_lut->rows()),
+              ElementsAre(0, 0, 0, 0, 0, 0));
+  EXPECT_THAT(std::vector<float>(lut_raw + query_lut->rows(),
+                                 lut_raw + query_lut->rows() * 2),
+              ElementsAre(-0.3, -0.3, -1.7, 0.6, 0.6, 2.4));
+}
+
+// Six queries; also checks the rearranged transposed_query_lut layout.
+TEST(ProcessorTest, AsymmetricHashQuerierSimd) {
+  AsymmetricHashingProto proto;
+  TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto);
+  auto querier = AsymmetricHashQuerier::Create(proto);
+  ASSERT_NE(querier, nullptr);
+  MatrixXf query(5, 6);
+  query << 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0,
+      1, 0, 1, 0, 0, 1, 1;
+  QueryInfo query_info;
+  ASSERT_TRUE(querier->Process(query, &query_info));
+  MatrixXf expected_lut(6, 6);
+  expected_lut << 0.05, 0.05, 1.45, 0.65, 0.85, 1.45, 0.05, 0.05, 1.45, 0.85,
+      0.65, 1.45, 1.45, 1.45, 0.05, 0.85, 0.65, 0.05, 0.14, 4.34, 0.14, 1.54,
+      2.94, 4.34, 0.14, 4.34, 0.14, 1.54, 2.94, 4.34, 1.94, 9.74, 1.94, 4.54,
+      7.14, 9.74;
+  ASSERT_TRUE(query_info.query_lut->isApprox(expected_lut, 1e-5));
+  expected_lut << 0.05, 1.45, 0.14, 0.14, 0.85, 1.45, 0.05, 0.85, 4.34, 1.54,
+      0.65, 1.45, 1.45, 1.45, 0.14, 1.94, 0.65, 0.05, 0.65, 1.45, 1.54, 9.74,
+      2.94, 4.34, 0.05, 0.05, 0.14, 1.94, 2.94, 4.34, 0.05, 0.85, 4.34, 4.54,
+      7.14, 9.74;
+  ASSERT_TRUE(query_info.transposed_query_lut->isApprox(expected_lut, 1e-5));
+}
+
+// Reuses one QueryInfo across batches of 2, 6 and 4 queries.
+TEST(ProcessorTest, AsymmetricHashPreprocessingLazyMemoryAllocation) {
+  AsymmetricHashingProto proto;
+  TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto);
+  auto querier = AsymmetricHashQuerier::Create(proto);
+  ASSERT_NE(querier, nullptr);
+  QueryInfo query_info;
+  {
+    MatrixXf query(5, 2);
+    query << 0, 0, 0, 0, 0, 1, 0, 1, 0, 1;
+    ASSERT_TRUE(querier->Process(query, &query_info));
+    MatrixXf expected_lut(6, 2);
+    expected_lut << 0.05, 0.05, 0.05, 0.05, 1.45, 1.45, 0.14, 4.34, 0.14, 4.34,
+        1.94, 9.74;
+    EXPECT_TRUE(query_info.query_lut->isApprox(expected_lut, 1e-5));
+    EXPECT_TRUE(query_info.transposed_query_lut->isApprox(expected_lut, 1e-5));
+  }
+  {
+    MatrixXf query(5, 6);
+    query << 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1,
+        0, 1, 0, 1, 0, 0, 1, 1;
+    ASSERT_TRUE(querier->Process(query, &query_info));
+    MatrixXf expected_lut(6, 6);
+    expected_lut << 0.05, 0.05, 1.45, 0.65, 0.85, 1.45, 0.05, 0.05, 1.45, 0.85,
+        0.65, 1.45, 1.45, 1.45, 0.05, 0.85, 0.65, 0.05, 0.14, 4.34, 0.14, 1.54,
+        2.94, 4.34, 0.14, 4.34, 0.14, 1.54, 2.94, 4.34, 1.94, 9.74, 1.94, 4.54,
+        7.14, 9.74;
+    EXPECT_TRUE(query_info.query_lut->isApprox(expected_lut, 1e-5));
+    expected_lut << 0.05, 1.45, 0.14, 0.14, 0.85, 1.45, 0.05, 0.85, 4.34, 1.54,
+        0.65, 1.45, 1.45, 1.45, 0.14, 1.94, 0.65, 0.05, 0.65, 1.45, 1.54, 9.74,
+        2.94, 4.34, 0.05, 0.05, 0.14, 1.94, 2.94, 4.34, 0.05, 0.85, 4.34, 4.54,
+        7.14, 9.74;
+    EXPECT_TRUE(query_info.transposed_query_lut->isApprox(expected_lut, 1e-5));
+  }
+  {
+    MatrixXf query(5, 4);
+    query << 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1;
+    ASSERT_TRUE(querier->Process(query, &query_info));
+    // The LUT is not shrunk for the smaller batch: its last two columns still
+    // hold the values computed for the previous six-query batch.
+    MatrixXf expected_lut(6, 6);
+    expected_lut << 1.45, 0.65, 0.85, 1.45, 0.85, 1.45, 1.45, 0.85, 0.65, 1.45,
+        0.65, 1.45, 0.05, 0.85, 0.65, 0.05, 0.65, 0.05, 0.14, 1.54, 2.94, 4.34,
+        2.94, 4.34, 0.14, 1.54, 2.94, 4.34, 2.94, 4.34, 1.94, 4.54, 7.14, 9.74,
+        7.14, 9.74;
+    EXPECT_TRUE(query_info.query_lut->isApprox(expected_lut, 1e-5));
+    expected_lut << 1.45, 0.65, 0.14, 2.94, 0.85, 1.45, 0.65, 1.45, 1.54, 4.34,
+        0.65, 1.45, 0.85, 0.05, 2.94, 1.94, 0.65, 0.05, 1.45, 0.85, 4.34, 4.54,
+        2.94, 4.34, 1.45, 0.65, 0.14, 7.14, 2.94, 4.34, 0.85, 0.05, 1.54, 9.74,
+        7.14, 9.74;
+    EXPECT_TRUE(query_info.transposed_query_lut->isApprox(expected_lut, 1e-5));
+  }
+}
+
+// INT16 lookup: the LUT is quantized between fixed_point_min and max.
+TEST(ProcessorTest, AsymmetricHashQuerierUint16) {
+  AsymmetricHashingProto proto;
+  TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto);
+  proto.set_lookup_type(AsymmetricHashingProto::INT16);
+  auto querier = AsymmetricHashQuerier::Create(proto);
+  ASSERT_NE(querier, nullptr);
+
+  MatrixXf query(5, 2);
+  query << 0, 1, 0, 1, 0, 1, 0, 1, 0, 1;
+  QueryInfo query_info;
+  ASSERT_TRUE(querier->Process(query, &query_info));
+  QueryInfo::Matrix<uint16_t> expected_lut(6, 2);
+  expected_lut << 0, 295, 0, 295, 295, 0, 19, 906, 19, 906, 399, 2047;
+
+  LOG(INFO) << *(query_info.query_lut_uint16);
+
+  ASSERT_EQ(*(query_info.query_lut_uint16), expected_lut);
+  EXPECT_NEAR(query_info.fixed_point_min, 0.05, 1e-4);
+  EXPECT_NEAR(query_info.fixed_point_max, 9.74, 1e-4);
+}
+
+// INT8 lookup: the LUT is quantized between fixed_point_min and max.
+TEST(ProcessorTest, AsymmetricHashQuerierUint8) {
+  AsymmetricHashingProto proto;
+  TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto);
+  proto.set_lookup_type(AsymmetricHashingProto::INT8);
+  auto querier = AsymmetricHashQuerier::Create(proto);
+  ASSERT_NE(querier, nullptr);
+
+  MatrixXf query(5, 2);
+  query << 0, 1, 0, 1, 0, 1, 0, 1, 0, 1;
+  QueryInfo query_info;
+  ASSERT_TRUE(querier->Process(query, &query_info));
+  QueryInfo::Matrix<uint8_t> expected_lut(6, 2);
+  expected_lut << 0, 36, 0, 36, 36, 0, 2, 112, 2, 112, 49, 255;
+  ASSERT_EQ(*(query_info.query_lut_uint8), expected_lut);
+  EXPECT_NEAR(query_info.fixed_point_min, 0.05, 1e-4);
+  EXPECT_NEAR(query_info.fixed_point_max, 9.74, 1e-4);
+}
+
+// Parameterized on the mini-batch size passed to AsymmetricHashLeafSearcher.
+class SearcherTest : public TestWithParam<size_t> {};
+
+// Brute-force squared-L2 search; each query keeps its 3 nearest points.
+TEST_P(SearcherTest, LinearLeafSearcherNonSimd) {
+  MatrixXf query(3, 2);
+  query << 0, 1, 2, 3, 3, 1;
+  std::shared_ptr<MatrixXf> database(new MatrixXf(3, 5));
+  *database << 0, 1, 2, 2, 1, 1, 0, 1, 2, 2, 2, 2, 5, 6, 1;
+  std::vector<TopN> top_n;
+  for (int i = 0; i < 2; ++i) {
+    top_n.emplace_back(
+        TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1)));
+  }
+  auto leaf_searcher = LinearLeafSearcher::Create(database);
+  ASSERT_NE(leaf_searcher, nullptr);
+  ASSERT_TRUE(leaf_searcher->FindNeighbors(query, &top_n));
+
+  constexpr float kEps = 1e-5;
+  auto extracted = top_n[0].Take();
+  EXPECT_NEAR(2.0, extracted[0].first, kEps);
+  EXPECT_NEAR(5.0, extracted[1].first, kEps);
+  EXPECT_NEAR(6.0, extracted[2].first, kEps);
+  EXPECT_EQ(0, extracted[0].second);
+  EXPECT_EQ(4, extracted[1].second);
+  EXPECT_EQ(1, extracted[2].second);
+
+  extracted = top_n[1].Take();
+  EXPECT_NEAR(1.0, extracted[0].first, kEps);
+  EXPECT_NEAR(6.0, extracted[1].first, kEps);
+  EXPECT_NEAR(10.0, extracted[2].first, kEps);
+  EXPECT_EQ(4, extracted[0].second);
+  EXPECT_EQ(0, extracted[1].second);
+  EXPECT_EQ(1, extracted[2].second);
+}
+
+// Dot-product search reports negated dot products, best match first.
+TEST_P(SearcherTest, LinearLeafSearcherNonSimdDotProduct) {
+  MatrixXf query(3, 2);
+  query << 0, 1, 2, 3, 3, 1;
+  auto database = std::make_shared<MatrixXf>(3, 5);
+  *database << 0, 1, 2, 2, 1, 1, 0, 1, 2, 2, 2, 2, 5, 6, 1;
+
+  std::vector<TopN> top_n(
+      2, TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1)));
+
+  auto leaf_searcher = LinearLeafSearcher::Create(database, DOT_PRODUCT);
+  ASSERT_NE(leaf_searcher, nullptr);
+  ASSERT_TRUE(leaf_searcher->FindNeighbors(query, &top_n));
+
+  auto extracted = top_n[0].Take();
+  EXPECT_THAT(extracted, ElementsAre(Pair(-22, 3), Pair(-17, 2), Pair(-8, 0)));
+
+  extracted = top_n[1].Take();
+  EXPECT_THAT(extracted, ElementsAre(Pair(-14, 3), Pair(-10, 2), Pair(-8, 4)));
+}
+
+// Asymmetric-hashing search with the default (unbounded) mini-batch size.
+TEST_P(SearcherTest, AsymmetricHashNonSimd) {
+  MatrixXf query(5, 2);
+  query << 0, 1, 0, 1, 0, 1, 0, 1, 0, 1;
+  std::shared_ptr<Matrix8u> database(new Matrix8u(2, 6));
+  *database << 0, 1, 2, 2, 1, 0, 1, 0, 1, 2, 2, 0;
+  AsymmetricHashingProto proto;
+  TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto);
+  auto querier = AsymmetricHashQuerier::Create(proto);
+  auto leaf_searcher =
+      AsymmetricHashLeafSearcher::Create(database, 0, std::move(querier));
+  ASSERT_NE(leaf_searcher, nullptr);
+  std::vector<TopN> top_n;
+  for (int i = 0; i < 2; ++i) {
+    top_n.emplace_back(
+        TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1)));
+  }
+  ASSERT_TRUE(leaf_searcher->FindNeighbors(query, &top_n));
+
+  constexpr float kEps = 1e-5;
+  auto extracted = top_n[0].Take();
+  EXPECT_NEAR(0.19, extracted[0].first, kEps);
+  EXPECT_NEAR(0.19, extracted[1].first, kEps);
+  EXPECT_NEAR(0.19, extracted[2].first, kEps);
+
+  extracted = top_n[1].Take();
+  EXPECT_NEAR(4.39, extracted[0].first, kEps);
+  EXPECT_NEAR(5.79, extracted[1].first, kEps);
+  EXPECT_NEAR(5.79, extracted[2].first, kEps);
+}
+
+#if defined(__ARM_NEON__) || defined(__SSE__)
+// Six queries searched with the parameterized mini-batch size.
+TEST_P(SearcherTest, AsymmetricHashSimdFloat32x4) {
+  MatrixXf query(5, 6);
+  query << 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0,
+      1, 0, 1, 0, 0, 1, 1;
+  std::shared_ptr<Matrix8u> database(new Matrix8u(2, 9));
+  *database << 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2;
+  AsymmetricHashingProto proto;
+  TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto);
+  auto querier = AsymmetricHashQuerier::Create(proto);
+  auto leaf_searcher = AsymmetricHashLeafSearcher::Create(
+      database, 0, std::move(querier), GetParam());
+  ASSERT_NE(leaf_searcher, nullptr);
+  std::vector<TopN> top_n;
+  for (int i = 0; i < 6; ++i) {
+    top_n.emplace_back(
+        TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1)));
+  }
+  ASSERT_TRUE(leaf_searcher->FindNeighbors(query, &top_n));
+
+  constexpr float kEps = 1e-5;
+  auto extracted = top_n[0].Take();
+  EXPECT_NEAR(0.19, extracted[0].first, kEps);
+  EXPECT_NEAR(0.19, extracted[1].first, kEps);
+  EXPECT_NEAR(0.19, extracted[2].first, kEps);
+
+  extracted = top_n[1].Take();
+  EXPECT_NEAR(4.39, extracted[0].first, kEps);
+  EXPECT_NEAR(4.39, extracted[1].first, kEps);
+  EXPECT_NEAR(4.39, extracted[2].first, kEps);
+
+  extracted = top_n[2].Take();
+  EXPECT_NEAR(0.19, extracted[0].first, kEps);
+  EXPECT_NEAR(0.19, extracted[1].first, kEps);
+  EXPECT_NEAR(1.59, extracted[2].first, kEps);
+
+  extracted = top_n[3].Take();
+  EXPECT_NEAR(2.19, extracted[0].first, kEps);
+  EXPECT_NEAR(2.19, extracted[1].first, kEps);
+  EXPECT_NEAR(2.39, extracted[2].first, kEps);
+
+  extracted = top_n[4].Take();
+  EXPECT_NEAR(3.59, extracted[0].first, kEps);
+  EXPECT_NEAR(3.59, extracted[1].first, kEps);
+  EXPECT_NEAR(3.59, extracted[2].first, kEps);
+
+  extracted = top_n[5].Take();
+  EXPECT_NEAR(4.39, extracted[0].first, kEps);
+  EXPECT_NEAR(4.39, extracted[1].first, kEps);
+  EXPECT_NEAR(5.79, extracted[2].first, kEps);
+}
+#endif
+
+#if defined(__ARM_NEON__) || defined(__SSE__)
+// INT16 lookup tables: distances are approximate, hence the looser kEps.
+TEST_P(SearcherTest, AsymmetricHashSimdInt16x8) {
+  MatrixXf query(5, 11);
+  query << 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0,
+      1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0,
+      1, 1, 1, 0, 0, 0, 0;
+  std::shared_ptr<Matrix8u> database(new Matrix8u(2, 9));
+  *database << 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2;
+  AsymmetricHashingProto proto;
+  TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto);
+  proto.set_lookup_type(AsymmetricHashingProto::INT16);
+  auto querier = AsymmetricHashQuerier::Create(proto);
+  ASSERT_NE(querier, nullptr);
+  QueryInfo query_info;
+  ASSERT_TRUE(querier->Process(query, &query_info));
+
+  auto leaf_searcher = AsymmetricHashLeafSearcher::Create(
+      database, 0, std::move(querier), GetParam());
+  ASSERT_NE(leaf_searcher, nullptr);
+  std::vector<TopN> top_n;
+  for (int i = 0; i < 11; ++i) {
+    top_n.emplace_back(
+        TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1)));
+  }
+  ASSERT_TRUE(leaf_searcher->FindNeighbors(query, &top_n));
+
+  auto extracted = top_n[0].Take();
+  constexpr float kEps = 5e-2;
+  EXPECT_NEAR(0.19, extracted[0].first, kEps);
+  EXPECT_NEAR(0.19, extracted[1].first, kEps);
+  EXPECT_NEAR(0.19, extracted[2].first, kEps);
+
+  extracted = top_n[1].Take();
+  EXPECT_NEAR(4.39, extracted[0].first, kEps);
+  EXPECT_NEAR(4.39, extracted[1].first, kEps);
+  EXPECT_NEAR(4.39, extracted[2].first, kEps);
+
+  extracted = top_n[2].Take();
+  EXPECT_NEAR(0.19, extracted[0].first, kEps);
+  EXPECT_NEAR(0.19, extracted[1].first, kEps);
+  EXPECT_NEAR(1.59, extracted[2].first, kEps);
+
+  extracted = top_n[3].Take();
+  EXPECT_NEAR(2.19, extracted[0].first, kEps);
+  EXPECT_NEAR(2.19, extracted[1].first, kEps);
+  EXPECT_NEAR(2.39, extracted[2].first, kEps);
+
+  extracted = top_n[4].Take();
+  EXPECT_NEAR(3.59, extracted[0].first, kEps);
+  EXPECT_NEAR(3.59, extracted[1].first, kEps);
+  EXPECT_NEAR(3.59, extracted[2].first, kEps);
+
+  extracted = top_n[5].Take();
+  EXPECT_NEAR(4.39, extracted[0].first, kEps);
+  EXPECT_NEAR(4.39, extracted[1].first, kEps);
+  EXPECT_NEAR(5.79, extracted[2].first, kEps);
+
+  extracted = top_n[6].Take();
+  EXPECT_NEAR(1.39, extracted[0].first, kEps);
+  EXPECT_NEAR(1.39, extracted[1].first, kEps);
+  EXPECT_NEAR(1.79, extracted[2].first, kEps);
+
+  extracted = top_n[7].Take();
+  EXPECT_NEAR(1.59, extracted[0].first, kEps);
+  EXPECT_NEAR(1.59, extracted[1].first, kEps);
+  EXPECT_NEAR(1.59, extracted[2].first, kEps);
+
+  extracted = top_n[8].Take();
+  EXPECT_NEAR(1.39, extracted[0].first, kEps);
+  EXPECT_NEAR(1.39, extracted[1].first, kEps);
+  EXPECT_NEAR(1.79, extracted[2].first, kEps);
+
+  extracted = top_n[9].Take();
+  EXPECT_NEAR(0.79, extracted[0].first, kEps);
+  EXPECT_NEAR(0.79, extracted[1].first, kEps);
+  EXPECT_NEAR(0.99, extracted[2].first, kEps);
+
+  extracted = top_n[10].Take();
+  EXPECT_NEAR(0.79, extracted[0].first, kEps);
+  EXPECT_NEAR(0.79, extracted[1].first, kEps);
+  EXPECT_NEAR(0.79, extracted[2].first, kEps);
+}
+#endif
+
+#if defined(__ARM_NEON__) || defined(__SSE__)
+// Queries are 6-D while the codebooks cover 2 + 3 = 5 dimensions and the
+// query distance is UNSPECIFIED; FindNeighbors is expected to fail.
+TEST_P(SearcherTest, AsymmetricHashMiniBatchedSimdFail) {
+  std::shared_ptr<Matrix8u> database(new Matrix8u(2, 9));
+  *database << 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2;
+  AsymmetricHashingProto proto;
+  TextFormat::ParseFromString(kExampleAsymmetricHashingProtoString, &proto);
+  proto.set_lookup_type(AsymmetricHashingProto::FLOAT);
+  proto.set_query_distance(DistanceMeasure::UNSPECIFIED);
+  auto querier = AsymmetricHashQuerier::Create(proto);
+  auto leaf_searcher = AsymmetricHashLeafSearcher::Create(
+      database, 0, std::move(querier), GetParam());
+  ASSERT_NE(leaf_searcher, nullptr);
+
+  MatrixXf queries(6, 6);
+  queries << 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0,
+      1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0;
+  std::vector<TopN> top_n;
+  for (int i = 0; i < queries.cols(); ++i) {
+    top_n.emplace_back(
+        TopN(3, std::make_pair(std::numeric_limits<float>::max(), -1)));
+  }
+  EXPECT_FALSE(leaf_searcher->FindNeighbors(queries, &top_n));
+}
+#endif
+
+INSTANTIATE_TEST_SUITE_P(
+    SearcherTest,
+    SearcherTest,
+    Values(std::numeric_limits<size_t>::max(), 1, 2, 3, 7, 23));
+
+}  // namespace
+
+}  // namespace core
+}  // namespace scann_ondevice
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.proto b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.proto index af0c372..5e0bfa4 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.proto +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.proto
@@ -23,7 +23,9 @@ } message PartitionerProto { - message Leaf { repeated float dimension = 1 [packed = true]; } + message Leaf { + repeated float dimension = 1 [packed = true]; + } repeated Leaf leaf = 1; @@ -33,9 +35,13 @@ } message AsymmetricHashingProto { - message CodebookEntry { repeated float dimension = 1 [packed = true]; } + message CodebookEntry { + repeated float dimension = 1 [packed = true]; + } - message SubspaceCodebook { repeated CodebookEntry entry = 1; } + message SubspaceCodebook { + repeated CodebookEntry entry = 1; + } repeated SubspaceCodebook subspace = 1; @@ -50,7 +56,9 @@ optional LookupType lookup_type = 3 [default = FLOAT]; } message IndexerProto { - oneof indexer { AsymmetricHashingProto asymmetric_hashing = 1; } + oneof indexer { + AsymmetricHashingProto asymmetric_hashing = 1; + } } message ScannOnDeviceConfig { optional DistanceMeasure query_distance = 1 [default = UNSPECIFIED];
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc new file mode 100644 index 0000000..e8be5f6 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.cc
@@ -0,0 +1,157 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/scann_ondevice/cc/index.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include "absl/memory/memory.h"  // from @com_google_absl
+#include "absl/status/status.h"  // from @com_google_absl
+#include "absl/status/statusor.h"  // from @com_google_absl
+#include "absl/strings/str_format.h"  // from @com_google_absl
+#include "absl/strings/string_view.h"  // from @com_google_absl
+#include "leveldb/cache.h"  // from @com_google_leveldb
+#include "leveldb/iterator.h"  // from @com_google_leveldb
+#include "leveldb/options.h"  // from @com_google_leveldb
+#include "leveldb/slice.h"  // from @com_google_leveldb
+#include "leveldb/status.h"  // from @com_google_leveldb
+#include "leveldb/table.h"  // from @com_google_leveldb
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/utils.h"
+#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
+
+namespace tflite {
+namespace scann_ondevice {
+
+namespace {
+
+// Looks up `key` with `iterator` and returns a view of the associated value.
+//
+// Returns NotFoundError if the key is absent from the table, and InternalError
+// if the iterator reports a read error (e.g. a corrupted block), so that
+// corruption is never mistaken for a missing key.
+//
+// Important: the underlying storage for the returned string view is owned by
+// the provided iterator, and is only guaranteed to be valid until this
+// iterator is used again. See:
+// https://github.com/google/leveldb/blob/main/include/leveldb/iterator.h
+absl::StatusOr<absl::string_view> GetValueForKey(leveldb::Iterator* iterator,
+                                                 const std::string& key) {
+  iterator->Seek(key);
+  // A failed read also leaves the iterator invalid, so the status must be
+  // checked before interpreting !Valid() as "key not present".
+  if (!iterator->status().ok()) {
+    return absl::InternalError(
+        absl::StrFormat("Unable to read key %s from the index: %s", key,
+                        iterator->status().ToString()));
+  }
+  // Seek() positions the iterator at the first key at or past `key`, so an
+  // exact match must be checked explicitly.
+  if (!iterator->Valid() || iterator->key() != key) {
+    return absl::NotFoundError(
+        absl::StrFormat("Unable to find key in the index: %s", key));
+  }
+  const leveldb::Slice value = iterator->value();
+  return absl::string_view(value.data(), value.size());
+}
+
+}  // namespace
+
+/* static */
+absl::StatusOr<std::unique_ptr<Index>> Index::CreateFromIndexBuffer(
+    const char* buffer_data,
+    size_t buffer_size) {
+  // Use absl::WrapUnique() to call private constructor:
+  // https://abseil.io/tips/126.
+  std::unique_ptr<Index> index = absl::WrapUnique(new Index());
+  RETURN_IF_ERROR(index->InitFromBuffer(buffer_data, buffer_size));
+  return index;
+}
+
+absl::StatusOr<IndexConfig> Index::GetIndexConfig() const {
+  std::string key(kIndexConfigKey);
+  ASSIGN_OR_RETURN(absl::string_view value,
+                   GetValueForKey(config_iterator_.get(), key));
+  IndexConfig config;
+  // Parse directly from the iterator-owned bytes; no intermediate copy needed.
+  if (!config.ParseFromArray(value.data(), static_cast<int>(value.size()))) {
+    return absl::InternalError("Unable to parse IndexConfig proto");
+  }
+  return config;
+}
+
+absl::StatusOr<absl::string_view> Index::GetUserInfo() const {
+  std::string key(kUserInfoKey);
+  // User info is optional: intercept NotFound errors and return an empty
+  // string instead. Any other error (e.g. a read error) is propagated.
+  auto user_info_or = GetValueForKey(info_iterator_.get(), key);
+  if (user_info_or.status().code() == absl::StatusCode::kNotFound) {
+    return "";
+  }
+  return user_info_or;
+}
+
+absl::StatusOr<absl::string_view> Index::GetPartitionAtIndex(uint32_t i) const {
+  std::string key(GetPartitionKey(i));
+  return GetValueForKey(embedding_iterator_.get(), key);
+}
+
+absl::StatusOr<absl::string_view> Index::GetMetadataAtIndex(uint32_t i) const {
+  std::string key(GetMetadataKey(i));
+  return GetValueForKey(metadata_iterator_.get(), key);
+}
+
+absl::Status Index::InitFromBuffer(const char* buffer_data,
+                                   size_t buffer_size) {
+  // Sanity check.
+  if (buffer_data == nullptr) {
+    return absl::InvalidArgumentError("Buffer cannot be null");
+  }
+  // Create file from buffer. The file does not take ownership of the buffer.
+  file_ = absl::make_unique<MemRandomAccessFile>(buffer_data, buffer_size);
+  // Create options with cache disabled, as this saves memory and has negligible
+  // impact on performance in this setup as any key can be accessed anytime.
+  leveldb::Options options;
+  cache_ = absl::WrapUnique(leveldb::NewLRUCache(0));
+  options.block_cache = cache_.get();
+  // Build Table from file and options. LevelDB requires `file_` to stay alive
+  // for as long as `table_` is.
+  leveldb::Table* table = nullptr;
+  leveldb::Status status =
+      leveldb::Table::Open(options, file_.get(), buffer_size, &table);
+  if (!status.ok()) {
+    return absl::InternalError(
+        absl::StrFormat("Unable to open levelDB table: %s", status.ToString()));
+  }
+  table_ = absl::WrapUnique(table);
+  // Create one iterator per getter, so that calls to one getter don't
+  // invalidate views returned by another one.
+  config_iterator_ =
+      absl::WrapUnique(table_->NewIterator(leveldb::ReadOptions()));
+  info_iterator_ =
+      absl::WrapUnique(table_->NewIterator(leveldb::ReadOptions()));
+  embedding_iterator_ =
+      absl::WrapUnique(table_->NewIterator(leveldb::ReadOptions()));
+  metadata_iterator_ =
+      absl::WrapUnique(table_->NewIterator(leveldb::ReadOptions()));
+  return absl::OkStatus();
+}
+
+}  // namespace scann_ondevice
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h new file mode 100644 index 0000000..15e70918 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index.h
@@ -0,0 +1,101 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_H_
+#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+
+#include "absl/status/status.h"  // from @com_google_absl
+#include "absl/status/statusor.h"  // from @com_google_absl
+#include "absl/strings/string_view.h"  // from @com_google_absl
+#include "leveldb/cache.h"  // from @com_google_leveldb
+#include "leveldb/iterator.h"  // from @com_google_leveldb
+#include "leveldb/table.h"  // from @com_google_leveldb
+#include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h"
+#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
+
+namespace tflite {
+namespace scann_ondevice {
+
+// Helper class for getting access to the data contained in the LevelDB index
+// file.
+//
+// This class is NOT thread-safe.
+class Index {
+ public:
+  // Creates an Index from the provided buffer. Ownership of the returned Index
+  // is transferred to the caller. Returns an error if the creation failed,
+  // which may happen e.g. if the provided buffer is null or is not a valid
+  // LevelDB index file.
+  //
+  // Warning: Does not take ownership of the provided buffer, which must outlive
+  // this object.
+  static absl::StatusOr<std::unique_ptr<Index>> CreateFromIndexBuffer(
+      const char* buffer_data,
+      size_t buffer_size);
+
+  // Parses and returns the `IndexConfig` stored in the index file.
+  absl::StatusOr<IndexConfig> GetIndexConfig() const;
+
+  // Provides access to the opaque user info stored in the index file (if any),
+  // in raw binary form. Returns an empty string if the index doesn't contain
+  // user info.
+  //
+  // Warning: In order to avoid unnecessary copies, the underlying pointer for
+  // the returned string view is only valid until next call to this method.
+  absl::StatusOr<absl::string_view> GetUserInfo() const;
+
+  // Provides access to the partition data corresponding to the i-th leaf in the
+  // order specified in the `IndexConfig`, in raw binary form.
+  //
+  // Warning: In order to avoid unnecessary copies, the underlying pointer for
+  // the returned string view is only valid until next call to this method.
+  absl::StatusOr<absl::string_view> GetPartitionAtIndex(uint32_t i) const;
+
+  // Provides access to the metadata associated with the i-th embedding in the
+  // index, in raw binary form.
+  //
+  // Warning: In order to avoid unnecessary copies, the underlying pointer for
+  // the returned string view is only valid until next call to this method.
+  absl::StatusOr<absl::string_view> GetMetadataAtIndex(uint32_t i) const;
+
+ private:
+  // Private default constructor, called from CreateFromIndexBuffer().
+  Index() = default;
+  // Initializes the Index from the provided buffer.
+  absl::Status InitFromBuffer(const char* buffer_data, size_t buffer_size);
+
+  // Members are destroyed in reverse declaration order, which matters here:
+  // leveldb::Table requires `file_` to outlive `table_`, and the iterators
+  // reference `table_` and may hold handles into `cache_`, so they must be
+  // destroyed first. Do not reorder these declarations.
+  std::unique_ptr<MemRandomAccessFile> file_;
+  std::unique_ptr<leveldb::Cache> cache_;
+  std::unique_ptr<leveldb::Table> table_;
+  // One iterator per getter, so that calls from one getter don't invalidate
+  // results from another one.
+  std::unique_ptr<leveldb::Iterator> config_iterator_;
+  std::unique_ptr<leveldb::Iterator> info_iterator_;
+  std::unique_ptr<leveldb::Iterator> embedding_iterator_;
+  std::unique_ptr<leveldb::Iterator> metadata_iterator_;
+};
+
+}  // namespace scann_ondevice
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc new file mode 100644 index 0000000..0d80202 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.cc
@@ -0,0 +1,207 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/scann_ondevice/cc/index_builder.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "absl/container/btree_map.h"  // from @com_google_absl
+#include "absl/status/status.h"  // from @com_google_absl
+#include "absl/status/statusor.h"  // from @com_google_absl
+#include "absl/strings/str_format.h"  // from @com_google_absl
+#include "absl/types/optional.h"  // from @com_google_absl
+#include "absl/types/span.h"  // from @com_google_absl
+#include "leveldb/options.h"  // from @com_google_leveldb
+#include "leveldb/slice.h"  // from @com_google_leveldb
+#include "leveldb/status.h"  // from @com_google_leveldb
+#include "leveldb/table_builder.h"  // from @com_google_leveldb
+#include "leveldb/write_batch.h"  // from @com_google_leveldb
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/utils.h"
+#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
+
+namespace tflite {
+namespace scann_ondevice {
+
+namespace {
+
+// Converts a LevelDB status into the closest matching absl::Status.
+absl::Status LevelDBStatusToAbsl(const leveldb::Status& leveldb_status) {
+  if (leveldb_status.ok()) {
+    return absl::OkStatus();
+  } else if (leveldb_status.IsInvalidArgument()) {
+    return absl::InvalidArgumentError(leveldb_status.ToString());
+  } else if (leveldb_status.IsNotFound()) {
+    return absl::NotFoundError(leveldb_status.ToString());
+  } else if (leveldb_status.IsNotSupportedError()) {
+    return absl::UnimplementedError(leveldb_status.ToString());
+  } else {
+    return absl::InternalError(leveldb_status.ToString());
+  }
+}
+
+// Serializes an index whose embeddings have element type `T` (uint8_t for a
+// hashed database, float for a float database) into a LevelDB table buffer.
+//
+// `database` must hold `metadata.size()` embeddings of
+// `index_config.embedding_dim()` elements each, stored consecutively. When
+// `partition_assignment` is set, embeddings and metadata are grouped by their
+// assigned partition; otherwise everything goes into a single partition. The
+// starting offset of each partition is appended to
+// `index_config.global_partition_offsets` before the config is serialized.
+template <typename T>
+absl::StatusOr<std::string> CreateIndexBufferImpl(
+    absl::Span<const T> database,
+    absl::optional<absl::Span<const uint32_t>> partition_assignment,
+    absl::Span<const std::string> metadata,
+    const std::string& userinfo,
+    IndexConfig index_config,
+    bool compression) {
+  size_t num_partitions = 1;
+  if (partition_assignment) {
+    if (partition_assignment->size() != metadata.size()) {
+      return absl::InvalidArgumentError(
+          "Size of partition assignment and metadata mismatch");
+    }
+    num_partitions = index_config.scann_config().partitioner().leaf_size();
+  }
+
+  const size_t embedding_dim = index_config.embedding_dim();
+  if (embedding_dim == 0) {
+    return absl::InvalidArgumentError("Embedding dimension must be positive");
+  }
+  // Checked by multiplication rather than division so that a database whose
+  // size is not a multiple of the dimension is rejected instead of having its
+  // trailing elements silently dropped.
+  if (database.size() != metadata.size() * embedding_dim) {
+    return absl::InvalidArgumentError(
+        "Number of embeddings differs from number of metadata");
+  }
+
+  std::vector<std::vector<char>> partition_bytes(num_partitions);
+  std::vector<std::vector<std::string>> partition_metadata(num_partitions);
+
+  const size_t per_embedding_bytes = sizeof(T) * embedding_dim;
+  const char* database_bytes = reinterpret_cast<const char*>(database.data());
+  for (size_t i = 0; i < metadata.size(); ++i) {
+    const size_t partition_idx =
+        partition_assignment ? (*partition_assignment)[i] : 0;
+    if (partition_idx >= num_partitions) {
+      return absl::InvalidArgumentError(absl::StrFormat(
+          "Partition index %d is larger than number of partitions: %d",
+          partition_idx, num_partitions));
+    }
+    partition_bytes[partition_idx].insert(
+        partition_bytes[partition_idx].end(),
+        database_bytes + i * per_embedding_bytes,
+        database_bytes + (i + 1) * per_embedding_bytes);
+    partition_metadata[partition_idx].push_back(metadata[i]);
+  }
+
+  // Flatten metadata in partition order, recording where each partition starts.
+  // Partitions are visited by reference so their strings are moved rather than
+  // copied, and each partition's memory is released once it is consumed.
+  std::vector<std::string> flatten_metadata;
+  flatten_metadata.reserve(metadata.size());
+  for (std::vector<std::string>& partition : partition_metadata) {
+    const size_t offset = flatten_metadata.size();
+    index_config.mutable_global_partition_offsets()->Add(offset);
+    flatten_metadata.insert(flatten_metadata.end(),
+                            std::make_move_iterator(partition.begin()),
+                            std::make_move_iterator(partition.end()));
+    partition.clear();
+    partition.shrink_to_fit();
+  }
+
+  std::string buffer;
+  ASSIGN_OR_RETURN(auto mem_writable_file, MemWritableFile::Create(&buffer));
+
+  leveldb::Options options;
+  options.compression =
+      compression ? leveldb::kSnappyCompression : leveldb::kNoCompression;
+  leveldb::TableBuilder table_builder(options, mem_writable_file.get());
+
+  // Keys must be added in ascending *lexical* order, e.g:
+  // E_0, E_1, E_10, E_11, [...], E_18, E_19, E_2, E_20, E_21, [...]
+  // We're using btree_map to reorder partition and metadata keys.
+  absl::btree_map<std::string, size_t> ordered_partition_key_to_index;
+  for (size_t i = 0; i < partition_bytes.size(); ++i) {
+    ordered_partition_key_to_index[GetPartitionKey(i)] = i;
+  }
+  for (const auto& [key, index] : ordered_partition_key_to_index) {
+    table_builder.Add(leveldb::Slice(key),
+                      leveldb::Slice(partition_bytes[index].data(),
+                                     partition_bytes[index].size()));
+  }
+  table_builder.Add(leveldb::Slice(kIndexConfigKey),
+                    leveldb::Slice(index_config.SerializeAsString()));
+  absl::btree_map<std::string, size_t> ordered_metadata_key_to_index;
+  for (size_t i = 0; i < flatten_metadata.size(); ++i) {
+    ordered_metadata_key_to_index[GetMetadataKey(i)] = i;
+  }
+  for (const auto& [key, index] : ordered_metadata_key_to_index) {
+    table_builder.Add(leveldb::Slice(key),
+                      leveldb::Slice(flatten_metadata[index]));
+  }
+  table_builder.Add(leveldb::Slice(kUserInfoKey), leveldb::Slice(userinfo));
+
+  const leveldb::Status status = table_builder.Finish();
+  if (!status.ok()) {
+    return LevelDBStatusToAbsl(status);
+  }
+
+  return buffer;
+}
+
+}  // namespace
+
+absl::StatusOr<std::string> CreateIndexBuffer(const IndexedArtifacts& artifacts,
+                                              bool compression) {
+  if (artifacts.hashed_database.has_value() &&
+      artifacts.float_database.has_value()) {
+    return absl::InvalidArgumentError(
+        "Can not have both float database and hashed database");
+  }
+
+  IndexConfig index_config;
+  *index_config.mutable_scann_config() = artifacts.config;
+  index_config.set_embedding_dim(artifacts.embedding_dim);
+  if (artifacts.hashed_database.has_value()) {
+    index_config.set_embedding_type(index_config.UINT8);
+    return CreateIndexBufferImpl(artifacts.hashed_database.value(),
+                                 artifacts.partition_assignment,
+                                 artifacts.metadata, artifacts.userinfo,
+                                 std::move(index_config), compression);
+  } else if (artifacts.float_database.has_value()) {
+    index_config.set_embedding_type(index_config.FLOAT);
+    return CreateIndexBufferImpl(artifacts.float_database.value(),
+                                 artifacts.partition_assignment,
+                                 artifacts.metadata, artifacts.userinfo,
+                                 std::move(index_config), compression);
+  } else {
+    return absl::InvalidArgumentError(
+        "Need either hashed_database or float_database");
+  }
+}
+
+}  // namespace scann_ondevice
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h new file mode 100644 index 0000000..53cac9b --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/index_builder.h
@@ -0,0 +1,74 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_BUILDER_H_
+#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_BUILDER_H_
+
+#include <cstdint>
+#include <string>
+
+#include "absl/status/statusor.h"  // from @com_google_absl
+#include "absl/strings/string_view.h"  // from @com_google_absl
+#include "absl/types/optional.h"  // from @com_google_absl
+#include "absl/types/span.h"  // from @com_google_absl
+#include "leveldb/db.h"  // from @com_google_leveldb
+#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
+#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
+
+namespace tflite {
+namespace scann_ondevice {
+
+// Inputs needed to build an on-device ScaNN index file, see
+// CreateIndexBuffer() below.
+struct IndexedArtifacts {
+  // Config for on-device ScaNN. Contains pretrained parts such as partition
+  // centroids and the compression codebook.
+  tflite::scann_ondevice::core::ScannOnDeviceConfig config;
+
+  // The dimension of each processed embedding in either hashed_database or
+  // float_database. Note that if hashing is enabled, it can be different from
+  // the original embedding dimension depending on the config.
+  uint32_t embedding_dim = 0;
+
+  // Flattened database embeddings. The embeddings should be stored
+  // consecutively in row major layout. Exactly one of hashed_database and
+  // float_database is expected. hashed_database can be either AH compressed or
+  // 8-bit quantized. In the case of 8-bit quantization, it's cast to uint8_t.
+  absl::optional<absl::Span<const uint8_t>> hashed_database;
+  absl::optional<absl::Span<const float>> float_database;
+
+  // The partition each database point belongs to, if the index uses a
+  // partitioner. The size should be the same as the number of database points.
+  absl::optional<absl::Span<const uint32_t>> partition_assignment;
+
+  // The metadata (label) for each database point. The size should be the same
+  // as the number of database points.
+  absl::Span<const std::string> metadata;
+
+  // An arbitrary user supplied string for storing custom information.
+  std::string userinfo;
+};
+
+// Creates a byte buffer for the index file from the artifacts. Returns an
+// error when not exactly one database is specified, or when there are other
+// issues with the input such as shape mismatch, invalid partition indices etc.
+// If `compression` is true, the table blocks are compressed with Snappy.
+absl::StatusOr<std::string> CreateIndexBuffer(const IndexedArtifacts& artifacts,
+                                              bool compression);
+
+}  // namespace scann_ondevice
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_INDEX_BUILDER_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc new file mode 100644 index 0000000..59b9deb8 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.cc
@@ -0,0 +1,56 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+
+#include "leveldb/env.h"  // from @com_google_leveldb
+#include "leveldb/slice.h"  // from @com_google_leveldb
+#include "leveldb/status.h"  // from @com_google_leveldb
+
+namespace tflite {
+namespace scann_ondevice {
+
+MemRandomAccessFile::MemRandomAccessFile(const char* buffer_data,
+                                         size_t buffer_size)
+    : buffer_data_(buffer_data), buffer_size_(buffer_size) {}
+
+MemRandomAccessFile::~MemRandomAccessFile() = default;
+
+// Serves reads straight out of the wrapped buffer: `result` aliases
+// `buffer_data_`, so `scratch` is never written to.
+leveldb::Status MemRandomAccessFile::Read(uint64_t offset,
+                                          size_t n,
+                                          leveldb::Slice* result,
+                                          char* /*scratch*/) const {
+  // A read may start exactly at the end (yielding an empty slice) but not
+  // past it.
+  if (offset > buffer_size_) {
+    return leveldb::Status::InvalidArgument(
+        "Read offset is beyond buffer size");
+  }
+  const size_t start = static_cast<size_t>(offset);
+  // Clamp the requested length to the bytes remaining after `start`.
+  const size_t remaining = buffer_size_ - start;
+  const size_t length = std::min(n, remaining);
+  *result = leveldb::Slice(buffer_data_ + start, length);
+  return leveldb::Status::OK();
+}
+
+}  // namespace scann_ondevice
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h new file mode 100644 index 0000000..5ca68f2e2 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h
@@ -0,0 +1,62 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_RANDOM_ACCESS_FILE_H_
+#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_RANDOM_ACCESS_FILE_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "leveldb/env.h"  // from @com_google_leveldb
+#include "leveldb/slice.h"  // from @com_google_leveldb
+#include "leveldb/status.h"  // from @com_google_leveldb
+
+namespace tflite {
+namespace scann_ondevice {
+
+// An implementation of LevelDB's RandomAccessFile [1] that wraps an in-memory
+// buffer.
+//
+// [1]: https://github.com/google/leveldb/blob/main/include/leveldb/env.h
+class MemRandomAccessFile : public leveldb::RandomAccessFile {
+ public:
+  // Constructor does not take ownership of the provided buffer, which must
+  // outlive this object.
+  MemRandomAccessFile(const char* buffer_data, size_t buffer_size);
+  ~MemRandomAccessFile() override;
+
+  // Override of the `Read` function. Sets `*result` to point directly into the
+  // wrapped buffer, truncated at its end, so `scratch` is unused. Returns
+  // InvalidArgument if `offset` is beyond the end of the buffer.
+  leveldb::Status Read(uint64_t offset,
+                       size_t n,
+                       leveldb::Slice* result,
+                       char* scratch) const override;
+
+  // Class is neither copyable nor movable: leveldb::RandomAccessFile deletes
+  // its copy operations and declares no move operations, so defaulted moves
+  // would be implicitly deleted anyway.
+  MemRandomAccessFile(const MemRandomAccessFile&) = delete;
+  MemRandomAccessFile& operator=(const MemRandomAccessFile&) = delete;
+
+ private:
+  const char* buffer_data_;
+  size_t buffer_size_;
+};
+
+}  // namespace scann_ondevice
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_RANDOM_ACCESS_FILE_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h new file mode 100644 index 0000000..842e83792 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/mem_writable_file.h
@@ -0,0 +1,67 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_WRITABLE_FILE_H_
+#define TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_WRITABLE_FILE_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/status/statusor.h"  // from @com_google_absl
+#include "absl/strings/cord.h"  // from @com_google_absl
+#include "leveldb/env.h"  // from @com_google_leveldb
+#include "leveldb/slice.h"  // from @com_google_leveldb
+#include "leveldb/status.h"  // from @com_google_leveldb
+
+namespace tflite {
+namespace scann_ondevice {
+
+// An implementation of LevelDB's WritableFile [1] that wraps an in-memory
+// buffer.
+//
+// [1]: https://github.com/google/leveldb/blob/main/include/leveldb/env.h
+class MemWritableFile : public leveldb::WritableFile {
+ public:
+  // Creates a MemWritableFile from a given buffer. Returns
+  // InvalidArgumentError if pointer is null. The buffer is not owned by the
+  // returned object and must outlive it.
+  static absl::StatusOr<std::unique_ptr<MemWritableFile>> Create(
+      std::string* buffer);
+
+  ~MemWritableFile() override = default;
+
+  // Class is neither copyable nor movable: leveldb::WritableFile deletes its
+  // copy operations and declares no move operations, so defaulted moves would
+  // be implicitly deleted anyway.
+  MemWritableFile(const MemWritableFile&) = delete;
+  MemWritableFile& operator=(const MemWritableFile&) = delete;
+
+  // leveldb::WritableFile overrides.
+  leveldb::Status Append(const leveldb::Slice& data) override;
+  leveldb::Status Close() override;
+  leveldb::Status Flush() override;
+  leveldb::Status Sync() override;
+
+ private:
+  // Use Create() instead.
+  explicit MemWritableFile(std::string* buffer);
+
+  std::string* buffer_;
+};
+
+}  // namespace scann_ondevice
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_SCANN_ONDEVICE_CC_MEM_WRITABLE_FILE_H_
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc new file mode 100644 index 0000000..7095640 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/python/index_builder_py_wrapper.cc
@@ -0,0 +1,75 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "absl/status/status.h"  // from @com_google_absl
+#include "absl/status/statusor.h"  // from @com_google_absl
+#include "absl/types/optional.h"  // from @com_google_absl
+#include "absl/types/span.h"  // from @com_google_absl
+#include "pybind11/cast.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11_abseil/absl_casters.h"  // from @pybind11_abseil
+#include "pybind11_abseil/status_casters.h"  // from @pybind11_abseil
+#include "tensorflow_lite_support/scann_ondevice/cc/core/serialized_searcher.pb.h"
+#include "tensorflow_lite_support/scann_ondevice/cc/index_builder.h"
+
+namespace pybind11 {
+
+PYBIND11_MODULE(index_builder, m) {
+  google::ImportStatusModule();
+
+  // Builds a serialized index file from the given artifacts and returns it as
+  // Python `bytes`. See `CreateIndexBuffer()` in index_builder.h for the
+  // meaning of each argument; `serialized_config` is a serialized
+  // `ScannOnDeviceConfig` proto.
+  m.def(
+      "create_serialized_index_file",
+      [](const uint32_t embedding_dim, const std::string& serialized_config,
+         const std::string& userinfo,
+         absl::Span<const uint32_t> partition_assignment,
+         absl::Span<const std::string> metadata, bool compression,
+         absl::optional<absl::Span<const uint8_t>> hashed_database,
+         absl::optional<absl::Span<const float>> float_database)
+          -> absl::StatusOr<bytes> {
+        tflite::scann_ondevice::core::ScannOnDeviceConfig config;
+        // Reject malformed configs instead of silently building an index from
+        // a config that failed to parse.
+        if (!config.ParseFromString(serialized_config)) {
+          return absl::InvalidArgumentError(
+              "Unable to parse serialized ScannOnDeviceConfig");
+        }
+        const auto status_or_bytes = tflite::scann_ondevice::CreateIndexBuffer(
+            {.config = config,
+             .embedding_dim = embedding_dim,
+             .hashed_database = hashed_database,
+             .float_database = float_database,
+             .partition_assignment = partition_assignment,
+             .metadata = metadata,
+             .userinfo = userinfo},
+            compression);
+        if (!status_or_bytes.ok()) {
+          return status_or_bytes.status();
+        }
+        return bytes(status_or_bytes.value());
+      },
+      arg("embedding_dim"), arg("serialized_config"), arg("userinfo"),
+      arg("partition_assignment"), arg("metadata"), arg("compression") = true,
+      arg("hashed_database") = absl::nullopt,
+      arg("float_database") = absl::nullopt);
+}
+
+}  // namespace pybind11
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc new file mode 100644 index 0000000..a1af840 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/index_builder_test.cc
@@ -0,0 +1,584 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/scann_ondevice/cc/index_builder.h"
+
+#include <cerrno>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <string>
+#include <vector>
+
+#include "absl/flags/flag.h"  // from @com_google_absl
+#include "absl/memory/memory.h"  // from @com_google_absl
+#include "absl/status/status.h"  // from @com_google_absl
+#include "absl/strings/str_format.h"  // from @com_google_absl
+#include "absl/strings/string_view.h"  // from @com_google_absl
+#include "absl/types/span.h"  // from @com_google_absl
+#include "leveldb/env.h"  // from @com_google_leveldb
+#include "leveldb/iterator.h"  // from @com_google_leveldb
+#include "leveldb/options.h"  // from @com_google_leveldb
+#include "leveldb/slice.h"  // from @com_google_leveldb
+#include "leveldb/status.h"  // from @com_google_leveldb
+#include "leveldb/table.h"  // from @com_google_leveldb
+#include "tensorflow_lite_support/cc/port/gmock.h"
+#include "tensorflow_lite_support/cc/port/gtest.h"
+#include "tensorflow_lite_support/cc/port/status_matchers.h"
+#include "tensorflow_lite_support/cc/test/message_matchers.h"
+#include "tensorflow_lite_support/cc/test/test_utils.h"
+#include "tensorflow_lite_support/scann_ondevice/proto/index_config.pb.h"
+
+namespace tflite {
+namespace scann_ondevice {
+namespace {
+
+using ::testing::Bool;
+using ::testing::ElementsAreArray;
+using ::testing::TestWithParam;
+using ::tflite::support::EqualsProto;
+using ::tflite::task::ParseTextProtoOrDie;
+
+// Writes `content` to `file_name`, truncating any existing file.
+//
+// The file is opened in binary mode so the serialized index is written
+// byte-for-byte on every platform; text mode may translate '\n' bytes (e.g.
+// on Windows) and corrupt the table.
+absl::Status SetContents(absl::string_view file_name,
+                         absl::string_view content) {
+  // `absl::string_view::data()` is not guaranteed to be NUL-terminated, so
+  // copy the path into a `std::string` before handing it to the C API.
+  const std::string path(file_name);
+  FILE* fp = fopen(path.c_str(), "wb");
+  if (fp == nullptr) {
+    return absl::InternalError(absl::StrFormat("Can't open file: %s", path));
+  }
+
+  const size_t written =
+      fwrite(content.data(), sizeof(char), content.size(), fp);
+  // A short write or a set stream error indicator both mean failure. ferror()
+  // only reports a flag, not an error code, so the message uses `errno`.
+  const bool write_failed = written != content.size() || ferror(fp) != 0;
+  // fclose() flushes buffered data, so its result must be checked as well.
+  const bool close_failed = fclose(fp) != 0;
+  if (write_failed || close_failed) {
+    return absl::InternalError(
+        absl::StrFormat("Error while writing file: %s. Error message: %s",
+                        path, strerror(errno)));
+  }
+  return absl::OkStatus();
+}
+
+// Seeks `iterator` to `key` and returns its value, or NotFound if the key is
+// absent or the iterator reports an error.
+absl::StatusOr<std::string> LookupKey(leveldb::Iterator* iterator,
+                                      absl::string_view key) {
+  iterator->Seek({key.data(), key.size()});
+  if (!iterator->Valid() || iterator->key().ToString() != key ||
+      !iterator->status().ok()) {
+    return absl::NotFoundError("Failed to lookup key");
+  }
+  return iterator->value().ToString();
+}
+
+constexpr size_t kDimensions = 2;
+constexpr size_t kNumEmbeddings = 24;
+constexpr size_t kNumPartitions = 12;
+
+// Expected IndexConfig for the partitioned (12-leaf) test databases, with
+// `embedding_type` overriding the UINT8 type set in the text proto.
+IndexConfig CreateExpectedConfigWithPartitioner(
+    IndexConfig::Type embedding_type) {
+  IndexConfig config = ParseTextProtoOrDie<IndexConfig>(R"pb(
+    scann_config {
+      partitioner {
+        leaf { dimension: 0 dimension: 0 }
+        leaf { dimension: 1 dimension: 1 }
+        leaf { dimension: 2 dimension: 2 }
+        leaf { dimension: 3 dimension: 3 }
+        leaf { dimension: 4 dimension: 4 }
+        leaf { dimension: 5 dimension: 5 }
+        leaf { dimension: 6 dimension: 6 }
+        leaf { dimension: 7 dimension: 7 }
+        leaf { dimension: 8 dimension: 8 }
+        leaf { dimension: 9 dimension: 9 }
+        leaf { dimension: 10 dimension: 10 }
+        leaf { dimension: 11 dimension: 11 }
+      }
+    }
+    embedding_dim: 2
+    embedding_type: UINT8
+    global_partition_offsets: 0
+    global_partition_offsets: 2
+    global_partition_offsets: 4
+    global_partition_offsets: 6
+    global_partition_offsets: 8
+    global_partition_offsets: 10
+    global_partition_offsets: 12
+    global_partition_offsets: 14
+    global_partition_offsets: 16
+    global_partition_offsets: 18
+    global_partition_offsets: 20
+    global_partition_offsets: 22
+  )pb");
+  config.set_embedding_type(embedding_type);
+  return config;
+}
+
+// Expected IndexConfig for the unpartitioned test databases.
+IndexConfig CreateExpectedConfigWithoutPartitioner(
+    IndexConfig::Type embedding_type) {
+  IndexConfig config = ParseTextProtoOrDie<IndexConfig>(R"pb(
+    scann_config { query_distance: SQUARED_L2_DISTANCE }
+    embedding_dim: 2
+    global_partition_offsets: 0
+  )pb");
+  config.set_embedding_type(embedding_type);
+  return config;
+}
+
+// Parameterized on whether the table is written with Snappy compression.
+class PopulateIndexFileTest : public TestWithParam<bool /*compression*/> {};
+
+TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithPartitioner) {
+  const std::string db_path =
+      tflite::task::JoinPath(getenv("TEST_TMPDIR"), "hashed");
+  const bool compression = GetParam();
+
+  {
+    tflite::scann_ondevice::core::ScannOnDeviceConfig config =
+        ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
+            R"pb(
+              partitioner: {
+                leaf { dimension: 0 dimension: 0 }
+                leaf { dimension: 1 dimension: 1 }
+                leaf { dimension: 2 dimension: 2 }
+                leaf { dimension: 3 dimension: 3 }
+                leaf { dimension: 4 dimension: 4 }
+                leaf { dimension: 5 dimension: 5 }
+                leaf { dimension: 6 dimension: 6 }
+                leaf { dimension: 7 dimension: 7 }
+                leaf { dimension: 8 dimension: 8 }
+                leaf { dimension: 9 dimension: 9 }
+                leaf { dimension: 10 dimension: 10 }
+                leaf { dimension: 11 dimension: 11 }
+              }
+            )pb");
+    std::vector<uint8_t> hashed_database;
+    hashed_database.reserve(kNumEmbeddings * kDimensions);
+    for (int i = 0; i < kNumEmbeddings; ++i) {
+      for (int j = 0; j < kDimensions; ++j) {
+        hashed_database.push_back(i);
+      }
+    }
+    std::vector<uint32_t> partition_assignment;
+    partition_assignment.reserve(kNumEmbeddings);
+    for (int i = 0; i < kNumEmbeddings; ++i) {
+      partition_assignment.push_back(i % kNumPartitions);
+    }
+    std::vector<std::string> metadata;
+    metadata.reserve(kNumEmbeddings);
+    for (int i = 0; i < kNumEmbeddings; ++i) {
+      metadata.push_back(absl::StrFormat("%d", i));
+    }
+    SUPPORT_ASSERT_OK_AND_ASSIGN(
+        const std::string buffer,
+        CreateIndexBuffer(
+            {.config = config,
+             .embedding_dim = kDimensions,
+             .hashed_database = absl::Span<uint8_t>(hashed_database),
+             .partition_assignment = absl::Span<uint32_t>(partition_assignment),
+             .metadata = absl::Span<std::string>(metadata),
+             .userinfo = "hashed_userinfo"},
+            compression));
+    SUPPORT_ASSERT_OK(SetContents(db_path, buffer));
+  }
+
+  auto* env = leveldb::Env::Default();
+  leveldb::RandomAccessFile* hash_file;
+  uint64_t hash_file_size;
+  ASSERT_TRUE(env->NewRandomAccessFile(db_path, &hash_file).ok());
+  auto hashed_file_unique = absl::WrapUnique(hash_file);
+  ASSERT_TRUE(env->GetFileSize(db_path, &hash_file_size).ok());
+
+  leveldb::Options options;
+  options.compression =
+      compression ? leveldb::kSnappyCompression : leveldb::kNoCompression;
+
+  leveldb::Table* hashed_table;
+  ASSERT_TRUE(
+      leveldb::Table::Open(options, hash_file, hash_file_size, &hashed_table)
+          .ok());
+  auto hashed_table_unique = absl::WrapUnique(hashed_table);
+  auto hashed_table_iterator =
+      absl::WrapUnique(hashed_table->NewIterator(leveldb::ReadOptions()));
+
+  SUPPORT_ASSERT_OK_AND_ASSIGN(
+      std::string serialized_config,
+      LookupKey(hashed_table_iterator.get(), "INDEX_CONFIG"));
+  IndexConfig index_config;
+  EXPECT_TRUE(index_config.ParseFromString(serialized_config));
+  EXPECT_THAT(
+      index_config,
+      EqualsProto(CreateExpectedConfigWithPartitioner(IndexConfig::UINT8)));
+
+  SUPPORT_ASSERT_OK_AND_ASSIGN(
+      std::string userinfo,
+      LookupKey(hashed_table_iterator.get(), "USER_INFO"));
+  EXPECT_EQ(userinfo, "hashed_userinfo");
+
+  // Partition assignment is based on i % kNumPartitions, so:
+  //   * partition 0 contains embeddings 0 and 12,
+  //   * partition 1 contains embeddings 1 and 13,
+  //   * etc
+  for (int i = 0; i < kNumPartitions; ++i) {
+    SUPPORT_ASSERT_OK_AND_ASSIGN(
+        std::string raw_partition_hashed,
+        LookupKey(hashed_table_iterator.get(), absl::StrFormat("E_%d", i)));
+    std::vector<char> hashed_partition(raw_partition_hashed.begin(),
+                                       raw_partition_hashed.end());
+    std::vector<char> expected = {static_cast<char>(i), static_cast<char>(i),
+                                  static_cast<char>(i + kNumPartitions),
+                                  static_cast<char>(i + kNumPartitions)};
+    EXPECT_THAT(hashed_partition, ElementsAreArray(expected));
+  }
+
+  // Similarly:
+  //   * metadata 0 contains metadata 0,
+  //   * metadata 1 contains metadata 12,
+  //   * metadata 2 contains metadata 1,
+  //   * metadata 3 contains metadata 13,
+  //   * etc
+  // Hence the `i / 2 + (i % 2 ? kNumPartitions : 0)` formula here.
+  for (int i = 0; i < kNumEmbeddings; ++i) {
+    SUPPORT_ASSERT_OK_AND_ASSIGN(
+        std::string metadata,
+        LookupKey(hashed_table_iterator.get(), absl::StrFormat("M_%d", i)));
+    EXPECT_EQ(metadata,
+              absl::StrFormat("%d", i / 2 + (i % 2 ? kNumPartitions : 0)));
+  }
+}
+
+TEST_P(PopulateIndexFileTest, WritesHashedDatabaseWithoutPartitioner) {
+  const std::string db_path =
+      tflite::task::JoinPath(getenv("TEST_TMPDIR"), "hashed_no_partitioner");
+  const bool compression = GetParam();
+
+  {
+    tflite::scann_ondevice::core::ScannOnDeviceConfig config =
+        ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
+            R"pb(
+              query_distance: SQUARED_L2_DISTANCE
+            )pb");
+    std::vector<uint8_t> hashed_database;
+    hashed_database.reserve(kNumEmbeddings * kDimensions);
+    for (int i = 0; i < kNumEmbeddings; ++i) {
+      for (int j = 0; j < kDimensions; ++j) {
+        hashed_database.push_back(i);
+      }
+    }
+    std::vector<std::string> metadata;
+    metadata.reserve(kNumEmbeddings);
+    for (int i = 0; i < kNumEmbeddings; ++i) {
+      metadata.push_back(absl::StrFormat("%d", i));
+    }
+    SUPPORT_ASSERT_OK_AND_ASSIGN(
+        const std::string buffer,
+        CreateIndexBuffer(
+            {.config = config,
+             .embedding_dim = kDimensions,
+             .hashed_database = absl::Span<uint8_t>(hashed_database),
+             .metadata = absl::Span<std::string>(metadata),
+             .userinfo = "hashed_userinfo"},
+            compression));
+    SUPPORT_ASSERT_OK(SetContents(db_path, buffer));
+  }
+
+  auto* env = leveldb::Env::Default();
+  leveldb::RandomAccessFile* float_file;
+  uint64_t float_file_size;
+  ASSERT_TRUE(env->NewRandomAccessFile(db_path, &float_file).ok());
+  auto float_file_unique = absl::WrapUnique(float_file);
+  ASSERT_TRUE(env->GetFileSize(db_path, &float_file_size).ok());
+
+  leveldb::Options options;
+  options.compression =
+      compression ? leveldb::kSnappyCompression : leveldb::kNoCompression;
+
+  leveldb::Table* float_table;
+  ASSERT_TRUE(
+      leveldb::Table::Open(options, float_file, float_file_size, &float_table)
+          .ok());
+  auto float_table_unique = absl::WrapUnique(float_table);
+  auto float_table_iterator =
+      absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions()));
+
+  SUPPORT_ASSERT_OK_AND_ASSIGN(
+      std::string serialized_config,
+      LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
+  IndexConfig index_config;
+  EXPECT_TRUE(index_config.ParseFromString(serialized_config));
+  EXPECT_THAT(
+      index_config,
+      EqualsProto(CreateExpectedConfigWithoutPartitioner(IndexConfig::UINT8)));
+
+  SUPPORT_ASSERT_OK_AND_ASSIGN(
+      std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO"));
+  EXPECT_EQ(userinfo, "hashed_userinfo");
+
+  // Check that the unique embedding partition has the exact same contents as
+  // the database used at construction time.
+  SUPPORT_ASSERT_OK_AND_ASSIGN(std::string raw_partition_hashed,
+                               LookupKey(float_table_iterator.get(), "E_0"));
+  std::vector<char> hashed_partition(raw_partition_hashed.begin(),
+                                     raw_partition_hashed.end());
+  std::vector<char> expected;
+  expected.reserve(kNumEmbeddings * kDimensions);
+  for (int i = 0; i < kNumEmbeddings; ++i) {
+    for (int j = 0; j < kDimensions; ++j) {
+      expected.push_back(i);
+    }
+  }
+  EXPECT_THAT(hashed_partition, ElementsAreArray(expected));
+
+  // Check metadata.
+  for (int i = 0; i < kNumEmbeddings; ++i) {
+    SUPPORT_ASSERT_OK_AND_ASSIGN(
+        std::string metadata,
+        LookupKey(float_table_iterator.get(), absl::StrFormat("M_%d", i)));
+    EXPECT_EQ(metadata, absl::StrFormat("%d", i));
+  }
+}
+
+TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithPartitioner) {
+  const std::string db_path =
+      tflite::task::JoinPath(getenv("TEST_TMPDIR"), "float");
+  const bool compression = GetParam();
+
+  {
+    tflite::scann_ondevice::core::ScannOnDeviceConfig config =
+        ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
+            R"pb(
+              partitioner: {
+                leaf { dimension: 0 dimension: 0 }
+                leaf { dimension: 1 dimension: 1 }
+                leaf { dimension: 2 dimension: 2 }
+                leaf { dimension: 3 dimension: 3 }
+                leaf { dimension: 4 dimension: 4 }
+                leaf { dimension: 5 dimension: 5 }
+                leaf { dimension: 6 dimension: 6 }
+                leaf { dimension: 7 dimension: 7 }
+                leaf { dimension: 8 dimension: 8 }
+                leaf { dimension: 9 dimension: 9 }
+                leaf { dimension: 10 dimension: 10 }
+                leaf { dimension: 11 dimension: 11 }
+              }
+            )pb");
+    std::vector<float> float_database;
+    float_database.reserve(kNumEmbeddings * kDimensions);
+    for (int i = 0; i < kNumEmbeddings; ++i) {
+      for (int j = 0; j < kDimensions; ++j) {
+        float_database.push_back(i);
+      }
+    }
+    std::vector<uint32_t> partition_assignment;
+    partition_assignment.reserve(kNumEmbeddings);
+    for (int i = 0; i < kNumEmbeddings; ++i) {
+      partition_assignment.push_back(i % kNumPartitions);
+    }
+    std::vector<std::string> metadata;
+    metadata.reserve(kNumEmbeddings);
+    for (int i = 0; i < kNumEmbeddings; ++i) {
+      metadata.push_back(absl::StrFormat("%d", i));
+    }
+    SUPPORT_ASSERT_OK_AND_ASSIGN(
+        const std::string buffer,
+        CreateIndexBuffer(
+            {.config = config,
+             .embedding_dim = kDimensions,
+             .float_database = absl::Span<float>(float_database),
+             .partition_assignment = absl::Span<uint32_t>(partition_assignment),
+             .metadata = absl::Span<std::string>(metadata),
+             .userinfo = "float_userinfo"},
+            compression));
+    SUPPORT_ASSERT_OK(SetContents(db_path, buffer));
+  }
+
+  auto* env = leveldb::Env::Default();
+  leveldb::RandomAccessFile* float_file;
+  uint64_t float_file_size;
+  ASSERT_TRUE(env->NewRandomAccessFile(db_path, &float_file).ok());
+  auto float_file_unique = absl::WrapUnique(float_file);
+  ASSERT_TRUE(env->GetFileSize(db_path, &float_file_size).ok());
+
+  leveldb::Options options;
+  options.compression =
+      compression ? leveldb::kSnappyCompression : leveldb::kNoCompression;
+
+  leveldb::Table* float_table;
+  ASSERT_TRUE(
+      leveldb::Table::Open(options, float_file, float_file_size, &float_table)
+          .ok());
+  auto float_table_unique = absl::WrapUnique(float_table);
+  auto float_table_iterator =
+      absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions()));
+
+  SUPPORT_ASSERT_OK_AND_ASSIGN(
+      std::string serialized_config,
+      LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
+  IndexConfig index_config;
+  EXPECT_TRUE(index_config.ParseFromString(serialized_config));
+  EXPECT_THAT(
+      index_config,
+      EqualsProto(CreateExpectedConfigWithPartitioner(IndexConfig::FLOAT)));
+
+  SUPPORT_ASSERT_OK_AND_ASSIGN(
+      std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO"));
+  EXPECT_EQ(userinfo, "float_userinfo");
+
+  // Partition assignment is based on i % kNumPartitions, so:
+  //   * partition 0 contains embeddings 0 and 12,
+  //   * partition 1 contains embeddings 1 and 13,
+  //   * etc
+  for (int i = 0; i < kNumPartitions; ++i) {
+    SUPPORT_ASSERT_OK_AND_ASSIGN(
+        std::string raw_partition_float,
+        LookupKey(float_table_iterator.get(), absl::StrFormat("E_%d", i)));
+    const float* raw_partition_float_ptr =
+        reinterpret_cast<const float*>(raw_partition_float.data());
+    std::vector<float> float_partition(
+        raw_partition_float_ptr,
+        raw_partition_float_ptr + raw_partition_float.size() / sizeof(float));
+    std::vector<float> expected = {static_cast<float>(i), static_cast<float>(i),
+                                   static_cast<float>(i + kNumPartitions),
+                                   static_cast<float>(i + kNumPartitions)};
+    EXPECT_THAT(float_partition, ElementsAreArray(expected));
+  }
+
+  // Similarly:
+  //   * metadata 0 contains metadata 0,
+  //   * metadata 1 contains metadata 12,
+  //   * metadata 2 contains metadata 1,
+  //   * metadata 3 contains metadata 13,
+  //   * etc
+  // Hence the `i / 2 + (i % 2 ? kNumPartitions : 0)` formula here.
+  for (int i = 0; i < kNumEmbeddings; ++i) {
+    SUPPORT_ASSERT_OK_AND_ASSIGN(
+        std::string metadata,
+        LookupKey(float_table_iterator.get(), absl::StrFormat("M_%d", i)));
+    EXPECT_EQ(metadata,
+              absl::StrFormat("%d", i / 2 + (i % 2 ? kNumPartitions : 0)));
+  }
+}
+
+TEST_P(PopulateIndexFileTest, WritesFloatDatabaseWithoutPartitioner) {
+  const std::string db_path =
+      tflite::task::JoinPath(getenv("TEST_TMPDIR"), "float_no_partitioner");
+  const bool compression = GetParam();
+
+  {
+    tflite::scann_ondevice::core::ScannOnDeviceConfig config =
+        ParseTextProtoOrDie<tflite::scann_ondevice::core::ScannOnDeviceConfig>(
+            R"pb(
+              query_distance: SQUARED_L2_DISTANCE
+            )pb");
+    std::vector<float> float_database;
+    float_database.reserve(kNumEmbeddings * kDimensions);
+    for (int i = 0; i < kNumEmbeddings; ++i) {
+      for (int j = 0; j < kDimensions; ++j) {
+        float_database.push_back(i);
+      }
+    }
+    std::vector<std::string> metadata;
+    metadata.reserve(kNumEmbeddings);
+    for (int i = 0; i < kNumEmbeddings; ++i) {
+      metadata.push_back(absl::StrFormat("%d", i));
+    }
+    SUPPORT_ASSERT_OK_AND_ASSIGN(
+        const std::string buffer,
+        CreateIndexBuffer({.config = config,
+                           .embedding_dim = kDimensions,
+                           .float_database = absl::Span<float>(float_database),
+                           .metadata = absl::Span<std::string>(metadata),
+                           .userinfo = "float_userinfo"},
+                          compression));
+    SUPPORT_ASSERT_OK(SetContents(db_path, buffer));
+  }
+
+  auto* env = leveldb::Env::Default();
+  leveldb::RandomAccessFile* float_file;
+  uint64_t float_file_size;
+  ASSERT_TRUE(env->NewRandomAccessFile(db_path, &float_file).ok());
+  auto float_file_unique = absl::WrapUnique(float_file);
+  ASSERT_TRUE(env->GetFileSize(db_path, &float_file_size).ok());
+
+  leveldb::Options options;
+  options.compression =
+      compression ? leveldb::kSnappyCompression : leveldb::kNoCompression;
+
+  leveldb::Table* float_table;
+  ASSERT_TRUE(
+      leveldb::Table::Open(options, float_file, float_file_size, &float_table)
+          .ok());
+  auto float_table_unique = absl::WrapUnique(float_table);
+  auto float_table_iterator =
+      absl::WrapUnique(float_table->NewIterator(leveldb::ReadOptions()));
+
+  SUPPORT_ASSERT_OK_AND_ASSIGN(
+      std::string serialized_config,
+      LookupKey(float_table_iterator.get(), "INDEX_CONFIG"));
+  IndexConfig index_config;
+  EXPECT_TRUE(index_config.ParseFromString(serialized_config));
+  EXPECT_THAT(
+      index_config,
+      EqualsProto(CreateExpectedConfigWithoutPartitioner(IndexConfig::FLOAT)));
+
+  SUPPORT_ASSERT_OK_AND_ASSIGN(
+      std::string userinfo, LookupKey(float_table_iterator.get(), "USER_INFO"));
+  EXPECT_EQ(userinfo, "float_userinfo");
+
+  // Check that the unique embedding partition has the exact same contents as
+  // the database used at construction time.
+  SUPPORT_ASSERT_OK_AND_ASSIGN(std::string raw_partition_float,
+                               LookupKey(float_table_iterator.get(), "E_0"));
+  const float* raw_partition_float_ptr =
+      reinterpret_cast<const float*>(raw_partition_float.data());
+  std::vector<float> float_partition(
+      raw_partition_float_ptr,
+      raw_partition_float_ptr + raw_partition_float.size() / sizeof(float));
+  std::vector<float> expected;
+  expected.reserve(kNumEmbeddings * kDimensions);
+  for (int i = 0; i < kNumEmbeddings; ++i) {
+    for (int j = 0; j < kDimensions; ++j) {
+      expected.push_back(i);
+    }
+  }
+  EXPECT_THAT(float_partition, ElementsAreArray(expected));
+
+  // Check metadata.
+  for (int i = 0; i < kNumEmbeddings; ++i) {
+    SUPPORT_ASSERT_OK_AND_ASSIGN(
+        std::string metadata,
+        LookupKey(float_table_iterator.get(), absl::StrFormat("M_%d", i)));
+    EXPECT_EQ(metadata, absl::StrFormat("%d", i));
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(PopulateIndexFileTest, PopulateIndexFileTest, Bool());
+
+}  // namespace
+}  // namespace scann_ondevice
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_random_access_file_test.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_random_access_file_test.cc new file mode 100644 index 0000000..2d1efb1 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/mem_random_access_file_test.cc
@@ -0,0 +1,70 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/scann_ondevice/cc/mem_random_access_file.h"
+
+#include <cstddef>
+
+#include "leveldb/slice.h"  // from @com_google_leveldb
+#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
+#include "tensorflow_lite_support/cc/port/gmock.h"
+#include "tensorflow_lite_support/cc/port/gtest.h"
+
+namespace tflite {
+namespace scann_ondevice {
+namespace {
+
+// Backing data for the file under test. The size is derived from the literal
+// (minus its trailing NUL) so that the two cannot drift apart.
+constexpr char kBufferData[] = "abcdef";
+constexpr size_t kBufferSize = sizeof(kBufferData) - 1;
+
+// Fixture wrapping `kBufferData` in a MemRandomAccessFile; `result_` receives
+// the slice produced by each Read().
+class MemRandomAccessFileTest : public tflite_shims::testing::Test {
+ public:
+  MemRandomAccessFileTest() : file_(kBufferData, kBufferSize) {}
+
+ protected:
+  MemRandomAccessFile file_;
+  leveldb::Slice result_;
+};
+
+TEST_F(MemRandomAccessFileTest, ReadFailsWithOutOfBoundsOffset) {
+  EXPECT_TRUE(file_.Read(/*offset=*/7, /*n=*/1, &result_, /*scratch=*/nullptr)
+                  .IsInvalidArgument());
+}
+
+TEST_F(MemRandomAccessFileTest, ReadSucceedsWithoutTruncation) {
+  EXPECT_TRUE(
+      file_.Read(/*offset=*/1, /*n=*/5, &result_, /*scratch=*/nullptr).ok());
+  EXPECT_EQ("bcdef", result_.ToString());
+}
+
+TEST_F(MemRandomAccessFileTest, ReadSucceedsWithTruncation) {
+  EXPECT_TRUE(
+      file_.Read(/*offset=*/1, /*n=*/6, &result_, /*scratch=*/nullptr).ok());
+  EXPECT_EQ("bcdef", result_.ToString());
+}
+
+TEST_F(MemRandomAccessFileTest, ReadSucceedsWithZeroLength) {
+  EXPECT_TRUE(
+      file_.Read(/*offset=*/1, /*n=*/0, &result_, /*scratch=*/nullptr).ok());
+  EXPECT_EQ("", result_.ToString());
+}
+
+}  // namespace
+}  // namespace scann_ondevice
+}  // namespace tflite
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/BUILD index 37ca204..cf0f2ca 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/BUILD
@@ -2,7 +2,7 @@ package( default_visibility = [ - "//tensorflow_lite_support/python/model_maker/core/utils:__subpackages__", + "//tensorflow_lite_support:internal", ], licenses = ["notice"], ) @@ -13,7 +13,7 @@ module_name = "leveldb_testing_utils", visibility = [ "//nlp/sage/learning/asqp:__subpackages__", - "//tensorflow_lite_support/python/model_maker/core/utils:__subpackages__", + "//tensorflow_lite_support:internal", ], deps = [ "@com_google_absl//absl/memory",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc new file mode 100644 index 0000000..1ae7e0c --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/scann_ondevice/cc/test/python/leveldb_testing_utils_py_wrapper.cc
@@ -0,0 +1,97 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"  // from @com_google_absl
+#include "absl/status/status.h"  // from @com_google_absl
+#include "absl/status/statusor.h"  // from @com_google_absl
+#include "absl/strings/str_format.h"  // from @com_google_absl
+#include "leveldb/env.h"  // from @com_google_leveldb
+#include "leveldb/options.h"  // from @com_google_leveldb
+#include "leveldb/status.h"  // from @com_google_leveldb
+#include "leveldb/table.h"  // from @com_google_leveldb
+#include "pybind11/cast.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11_abseil/absl_casters.h"  // from @pybind11_abseil
+#include "pybind11_abseil/status_casters.h"  // from @pybind11_abseil
+
+namespace pybind11 {
+
+PYBIND11_MODULE(leveldb_testing_utils, m) {
+  google::ImportStatusModule();
+
+  // Reads the LevelDB table stored at the given path and returns every entry
+  // as a list of (key, value) `bytes` pairs, in the table's iteration order.
+  // `compressed` selects Snappy vs. no compression in the leveldb::Options
+  // used to open the table.
+  //
+  // NOTE: the first keyword argument is exposed to Python as `buffer` even
+  // though it is a file path; the name is kept for compatibility with
+  // existing keyword callers.
+  m.def(
+      "leveldb_table_to_pair_list",
+      [](const std::string& fname, bool compressed)
+          -> absl::StatusOr<std::vector<std::pair<bytes, bytes>>> {
+        auto* env = leveldb::Env::Default();
+        leveldb::RandomAccessFile* file = nullptr;
+        if (!env->NewRandomAccessFile(fname, &file).ok()) {
+          return absl::InternalError(absl::StrFormat(
+              "Failed to create RandomAccessFile at %s", fname));
+        }
+        // Owns `file`. Declared before `unique_table` below, so it is
+        // destroyed after the table that reads from it.
+        auto unique_file = absl::WrapUnique(file);
+        uint64_t file_size = 0;
+        if (!env->GetFileSize(fname, &file_size).ok()) {
+          return absl::InternalError(
+              absl::StrFormat("Failed to get file size at %s", fname));
+        }
+        leveldb::Options options;
+        options.compression =
+            compressed ? leveldb::kSnappyCompression : leveldb::kNoCompression;
+
+        leveldb::Table* table = nullptr;
+        if (!leveldb::Table::Open(options, file, file_size, &table).ok()) {
+          return absl::InternalError("Failed to open table");
+        }
+        auto unique_table = absl::WrapUnique(table);
+        auto table_iterator =
+            absl::WrapUnique(table->NewIterator(leveldb::ReadOptions()));
+        table_iterator->SeekToFirst();
+
+        std::vector<std::pair<bytes, bytes>> result;
+        for (; table_iterator->Valid(); table_iterator->Next()) {
+          result.emplace_back(bytes(table_iterator->key().ToString()),
+                              bytes(table_iterator->value().ToString()));
+        }
+        // Valid() == false does not distinguish end-of-table from a read
+        // error; report errors instead of returning a truncated list.
+        const leveldb::Status status = table_iterator->status();
+        if (!status.ok()) {
+          return absl::InternalError(absl::StrFormat(
+              "Failed to iterate over table at %s: %s", fname,
+              status.ToString()));
+        }
+        return result;
+      },
+      arg("buffer"), arg("compressed"));
+}
+
+}  // namespace pybind11
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/Build_TFLite_Support_Targets.ipynb b/third_party/tflite_support/src/tensorflow_lite_support/tools/Build_TFLite_Support_Targets.ipynb index 3e7e64e..1ce3bf2 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/Build_TFLite_Support_Targets.ipynb +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/Build_TFLite_Support_Targets.ipynb
@@ -97,7 +97,7 @@ "%env PATH = /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/tools/node/bin:/tools/google-cloud-sdk/bin:/opt/bin:/android/sdk/tools:/android/sdk/platform-tools:/android/ndk\n", "%env ANDROID_SDK_API_LEVEL=29\n", "%env ANDROID_API_LEVEL=29\n", - "%env ANDROID_BUILD_TOOLS_VERSION=29.0.2\n", + "%env ANDROID_BUILD_TOOLS_VERSION=30.0.0\n", "%env ANDROID_DEV_HOME=/android\n", "%env ANDROID_NDK_API_LEVEL=21\n", "%env ANDROID_NDK_FILENAME=android-ndk-r19c-linux-x86_64.zip\n",
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/expand_template.bzl b/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/expand_template.bzl deleted file mode 100644 index 717860c..0000000 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/build_rules/expand_template.bzl +++ /dev/null
@@ -1,50 +0,0 @@ -"""Build macro for libzip.""" - -# forked from kythe/kythe/tools/build_rules/expand_template.bzl -def _expand_template_impl(ctx): - ctx.actions.expand_template( - template = ctx.file.template, - output = ctx.outputs.out, - substitutions = ctx.attr.substitutions, - ) - -expand_template = rule( - attrs = { - "out": attr.output(mandatory = True), - "substitutions": attr.string_dict(mandatory = True), - "template": attr.label( - mandatory = True, - allow_single_file = True, - ), - }, - output_to_genfiles = True, - implementation = _expand_template_impl, -) - -def cmake_substitutions(vars, defines = {}): - """Returns a dict of template substitutions combining `vars` and `defines`. - - Args: - vars: will be turned into a dict replacing `${key}` and `@key@` with `value`. - defines: will be turned into a dict replacing `#cmakedefine` with `#define {value}` - if present is true, otherwise `/* #undef %s /*`. - Returns: - substitutions - """ - subs = {} - for key, value in vars.items(): - subs["${%s}" % (key,)] = str(value) if value != None else "" - subs["@%s@" % (key,)] = str(value) if value != None else "" - - # TODO(shahms): Better handling of #cmakedefine delimiters and line endings to - # avoid the prefix-substitution problem. - # Potentially allow value to be: True, False, None or string. - # True/False => Same as current - # None => assume no suffix value, include \n in sub and replacement - # string => use string to lookup in vars and assume ${} or @@ tail? - for macro, present in defines.items(): - if present: - subs["#cmakedefine %s" % macro] = "#define %s" % macro - else: - subs["#cmakedefine %s" % macro] = "/* #undef %s */" % macro - return subs
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh index cb3c2bc..f21a52b8 100755 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh
@@ -19,10 +19,17 @@ bash tensorflow_lite_support/custom_ops/tf_configure.sh +# Compile the two schema srcjars first. Compiling +# tensorflow-lite-support-metadata-lib directly will lead to a racing issue +# similar to b/200756963 that two sets of TFLite schema source files are +# generated, and will thus cause the duplicated Java class error. +bazel build -c opt --config=monolithic tensorflow_lite_support/metadata:schema_fbs_java_srcjar +bazel build -c opt --config=monolithic tensorflow_lite_support/metadata:metadata_schema_java_srcjar + # Break down metadata builds and avoid the flacky issue when building schema # with Bazel. See b/200756963. bazel build -c opt --config=monolithic \ - //tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata_lib + //tensorflow_lite_support/metadata/java:tensorflow-lite-support-metadata-lib bazel build -c opt --config=monolithic \ //tensorflow_lite_support/metadata/cc:metadata_extractor @@ -51,10 +58,10 @@ //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio:task-library-audio \ //tensorflow_lite_support/acceleration/configuration:gpu-delegate-plugin -bazel clean -# Coral plugin. -bazel build -c opt ${BAZEL_PARALLEL} --define=darwinn_portable=1 \ - //tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin +# Pip package +bazel build -c opt ${BAZEL_PARALLEL} \ + --define darwinn_portable=1 \ + tensorflow_lite_support/tools/pip_package:build_pip_package # Tests.
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh index 43b37721..ad7663df 100755 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh
@@ -163,6 +163,10 @@ # Copy source files with the intermediate directories preserved. xargs -n1 -I{} rsync -R {} "${TFLS_TMPDIR}" <<< "${SRC_FILES}" + + # Copy the license file to TFLS_TMPDIR + cp "${TFLS_ROOT_DIR}/LICENSE" ${TFLS_TMPDIR} + popd pushd "${TFLS_ROOT_DIR}"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common.sh index 135e428..45f6f6fb 100755 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common.sh +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common.sh
@@ -16,7 +16,7 @@ # External `common.sh` # Keep in sync with tensorflow core and configure.py. -LATEST_BAZEL_VERSION=4.2.2 +LATEST_BAZEL_VERSION=5.1.1 # Run flaky functions with retries. # run_with_retry cmd @@ -60,7 +60,7 @@ esac mkdir -p "$HOME/bin" wget --no-verbose -O "$HOME/bin/bazel" \ - "https://github.com/bazelbuild/bazelisk/releases/download/v1.3.0/$name" + "https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/$name" chmod u+x "$HOME/bin/bazel" if [[ ! ":$PATH:" =~ :"$HOME"/bin/?: ]]; then PATH="$HOME/bin:$PATH"
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/docs/build_py_api_docs.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/docs/build_py_api_docs.py index 6457c80..6649311 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/docs/build_py_api_docs.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/docs/build_py_api_docs.py
@@ -62,21 +62,32 @@ # tflite_support and tensorflow_lite_support. The former is the main # interface, but it imports from the latter, so we need to include it in the # doc scope. - tflite_support_dir = pathlib.Path(tflite_support.__file__).parent + tflite_support_base_dir = pathlib.Path(tflite_support.__file__).parent tensorflow_lite_support_dir = pathlib.Path( tensorflow_lite_support.__file__).parent - # schema_py_generated is a generated API so we can't use annotations to + # Additionally, the tflite_support package is composed of smaller packages + # that live in separate directories. To ensure "view code" URLs work, list + # them explicitly alongside tensorflow_lite_support, which also lives + # somewhere else. + base_dirs = [ + tensorflow_lite_support_dir, + tflite_support_base_dir / 'metadata_writers', + tflite_support_base_dir / 'task'] + code_prefixes = [ + _CODE_PREFIX.value, + f'{_CODE_PREFIX.value}/metadata/python/metadata_writers', + f'{_CODE_PREFIX.value}/python/task'] + + # schema_py_generated is a generated API, so we can't use annotations to # suppress doc generation. del tflite_support.schema_py_generated doc_generator = generate_lib.DocGenerator( root_title='TensorFlow Lite Support', py_modules=[('tflite_support', tflite_support)], - base_dir=[tflite_support_dir, tensorflow_lite_support_dir], - # The two base_dirs have different roots in the GH repo. - code_url_prefix=[_CODE_PREFIX.value + '/metadata/python', - _CODE_PREFIX.value], + base_dir=base_dirs, + code_url_prefix=code_prefixes, search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, callbacks=[])
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/BUILD index 5f62f9d..2fb06eee 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/BUILD +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/BUILD
@@ -19,14 +19,22 @@ "//tensorflow_lite_support/metadata/python/metadata_writers:image_segmenter", "//tensorflow_lite_support/metadata/python/metadata_writers:audio_classifier", "//tensorflow_lite_support/metadata/python/metadata_writers:nl_classifier", - # For Model Maker Searcher API to build ScaNN index. - "//tensorflow_lite_support/scann_ondevice/cc/python:index_builder", ] TASK_PIP_DEPS = [ "//tensorflow_lite_support/python/task/vision:image_classifier", "//tensorflow_lite_support/python/task/vision:image_embedder", + "//tensorflow_lite_support/python/task/vision:image_segmenter", + "//tensorflow_lite_support/python/task/vision:image_searcher", "//tensorflow_lite_support/python/task/vision:object_detector", + "//tensorflow_lite_support/python/task/text:text_embedder", + "//tensorflow_lite_support/python/task/text:text_searcher", + "//tensorflow_lite_support/python/task/audio:audio_classifier", + "//tensorflow_lite_support/python/task/audio:audio_embedder", + # For Model Maker Searcher API to build ScaNN index. + "//tensorflow_lite_support/scann_ondevice/cc/python:index_builder", + "//tensorflow_lite_support/scann_ondevice/cc/core:serialized_searcher_py_pb2", + "//tensorflow_lite_support/scann_ondevice/cc/test/python:leveldb_testing_utils", ] filegroup( @@ -67,4 +75,5 @@ srcs = ["simple_console_for_windows.py"], data = COMMON_PIP_DEPS, srcs_version = "PY2AND3", + deps = [], )
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/build_pip_package.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/build_pip_package.sh index f2b21fc..eaaa2b7 100755 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/build_pip_package.sh +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/build_pip_package.sh
@@ -98,9 +98,14 @@ # Task Library is not supported on Windows yet. mkdir ${TMPDIR}/tflite_support/task mkdir ${TMPDIR}/tflite_support/task/core + cp tensorflow_lite_support/tools/pip_package/task.__init__.py ${TMPDIR}/tflite_support/task/__init__.py cp tensorflow_lite_support/tools/pip_package/task_core.__init__.py ${TMPDIR}/tflite_support/task/core/__init__.py mkdir ${TMPDIR}/tflite_support/task/vision cp tensorflow_lite_support/tools/pip_package/task_vision.__init__.py ${TMPDIR}/tflite_support/task/vision/__init__.py + mkdir ${TMPDIR}/tflite_support/task/text + cp tensorflow_lite_support/tools/pip_package/task_text.__init__.py ${TMPDIR}/tflite_support/task/text/__init__.py + mkdir ${TMPDIR}/tflite_support/task/audio + cp tensorflow_lite_support/tools/pip_package/task_audio.__init__.py ${TMPDIR}/tflite_support/task/audio/__init__.py mkdir ${TMPDIR}/tflite_support/task/processor cp tensorflow_lite_support/tools/pip_package/task_processor.__init__.py ${TMPDIR}/tflite_support/task/processor/__init__.py fi
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/Dockerfile.py3 b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/Dockerfile.py3 index a1eb166..2d11a243 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/Dockerfile.py3 +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/Dockerfile.py3
@@ -26,7 +26,8 @@ zlib1g-dev \ curl \ unzip \ - git && \ + git \ + xxd && \ apt-get clean RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata @@ -42,13 +43,17 @@ python$PYTHON_VERSION-distutils \ libpython$PYTHON_VERSION-dev \ libpython$PYTHON_VERSION-dev:armhf \ - libpython$PYTHON_VERSION-dev:arm64 + libpython$PYTHON_VERSION-dev:arm64 \ + libusb-1.0-0-dev \ + libusb-1.0-0-dev:armhf \ + libusb-1.0-0-dev:arm64 + RUN ln -sf /usr/bin/python$PYTHON_VERSION /usr/bin/python3 RUN curl -OL https://bootstrap.pypa.io/get-pip.py RUN python3 get-pip.py RUN rm get-pip.py RUN pip3 install --upgrade pip -RUN pip3 install numpy~=1.19.2 setuptools pybind11 +RUN pip3 install numpy~=1.20.0 setuptools pybind11 RUN ln -sf /usr/include/python$PYTHON_VERSION /usr/include/python3 RUN ln -sf /usr/local/lib/python$PYTHON_VERSION/dist-packages/numpy/core/include/numpy /usr/include/python3/numpy RUN ln -sf /usr/bin/python3 /usr/bin/python
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/build_arm_pip_package.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/build_arm_pip_package.sh index 5580f4a..6070cc9 100755 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/build_arm_pip_package.sh +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/build_arm_pip_package.sh
@@ -18,7 +18,7 @@ NIGHTLY_FLAG=$1 -bazel build -c opt --config=elinux_armhf tensorflow_lite_support/tools/pip_package:build_pip_package +bazel build -c opt --config=elinux_armhf --define darwinn_portable=1 --linkopt=-L/usr/lib/arm-linux-gnueabihf tensorflow_lite_support/tools/pip_package:build_pip_package EXTRA_PKG_NAME_FLAG="--plat-name=manylinux2014-armv7l" ./bazel-bin/tensorflow_lite_support/tools/pip_package/build_pip_package --dst wheels ${NIGHTLY_FLAG} -bazel build -c opt --config=elinux_aarch64 tensorflow_lite_support/tools/pip_package:build_pip_package +bazel build -c opt --config=elinux_aarch64 --define darwinn_portable=1 --linkopt=-L/usr/lib/aarch64-linux-gnu tensorflow_lite_support/tools/pip_package:build_pip_package EXTRA_PKG_NAME_FLAG="--plat-name=manylinux2014-aarch64" ./bazel-bin/tensorflow_lite_support/tools/pip_package/build_pip_package --dst wheels ${NIGHTLY_FLAG}
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/install_bazel.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/install_bazel.sh index 9e3c076..593838e 100755 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/install_bazel.sh +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/rpi/install_bazel.sh
@@ -15,7 +15,7 @@ # ============================================================================== # Select bazel version. -BAZEL_VERSION="4.2.2" +BAZEL_VERSION="5.1.1" set +e local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/setup.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/setup.py index 59b7ca78..77254d22 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/setup.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/setup.py
@@ -16,12 +16,20 @@ This PyPI package includes the Python bindings for following features: + - Task Library: a set of powerful and easy-to-use task-specific libraries to + integrate TFLite models onto various platforms. See the [Task Library + documentation](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview) + for more information. - Metadata schemas: wraps TFLite model schema and metadata schema in Python. - - Metadata populator and displayer: can be used to populate the metadata and + - Metadata writer and displayer: can be used to populate the metadata and associated files into the model, as well as converting the populated metadata - into the json format. + into the json format. See the [Metadata + documentation](https://www.tensorflow.org/lite/convert/metadata) for more + information. - Android Codegen tool: generates the Java model interface used in Android for - a particular model. + a particular model. See the [Codegen tool + documentation](https://www.tensorflow.org/lite/inference_with_metadata/codegen) + for more information. """ from __future__ import absolute_import @@ -42,7 +50,7 @@ # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. -_VERSION = '0.3.0' +_VERSION = '0.4.0' SETUP_PACKAGES = [ 'pybind11 >= 2.6.0', @@ -50,11 +58,13 @@ REQUIRED_PACKAGES = [ 'absl-py >= 0.7.0', - 'numpy >= 1.19.2', + 'numpy >= 1.20.0', # TODO(b/187981032): remove the constraint for 2.0 once the incompatibile # issue is resolved. 'flatbuffers >= 1.12, <2', + # The Protobuf version needs to be the same as the one in WORKSPACE. 'protobuf >= 3.18.0', + 'sounddevice >= 0.4.4', ] + SETUP_PACKAGES project_name = 'tflite-support'
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task.__init__.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task.__init__.py new file mode 100644 index 0000000..53c8265 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task.__init__.py
@@ -0,0 +1,32 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The TensorFlow Lite Task Library. + +TensorFlow Lite Task Library contains a set of powerful and easy-to-use +task-specific libraries for app developers to create ML experiences with +TensorFlow Lite. It provides optimized out-of-box model interfaces for popular +machine learning tasks, such as image and text classification. The model +interfaces are specifically designed for each task to achieve the best +performance and usability. + +Read more in the [Task Library Guide]( +https://tensorflow.org/lite/inference_with_metadata/task_library/overview). +""" + +from . import audio +from . import core +from . import processor +from . import text +from . import vision
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_audio.__init__.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_audio.__init__.py new file mode 100644 index 0000000..6e161f0f --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_audio.__init__.py
@@ -0,0 +1,37 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite Task Library Audio APIs. + +This module provides interface to run TensorFlow Lite audio models. +""" + +from tensorflow_lite_support.python.task.audio import audio_classifier +from tensorflow_lite_support.python.task.audio import audio_embedder +from tensorflow_lite_support.python.task.audio.core import audio_record +from tensorflow_lite_support.python.task.audio.core import tensor_audio + +AudioClassifier = audio_classifier.AudioClassifier +AudioClassifierOptions = audio_classifier.AudioClassifierOptions +AudioEmbedder = audio_embedder.AudioEmbedder +AudioEmbedderOptions = audio_embedder.AudioEmbedderOptions +AudioRecord = audio_record.AudioRecord +AudioFormat = tensor_audio.AudioFormat +TensorAudio = tensor_audio.TensorAudio + +# Remove unnecessary modules to avoid duplication in API docs. +del audio_classifier +del audio_embedder +del audio_record +del tensor_audio \ No newline at end of file
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_core.__init__.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_core.__init__.py index a934d2a..ef1ef713 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_core.__init__.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_core.__init__.py
@@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""An import entry for the Task Core module.""" +"""TensorFlow Lite Task Library's core module. + +This module contains classes used across multiple tasks in the Task Library.""" from tensorflow_lite_support.python.task.core.proto import base_options_pb2 BaseOptions = base_options_pb2.BaseOptions + +# Remove unnecessary modules to avoid duplication in API docs. +del base_options_pb2
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_processor.__init__.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_processor.__init__.py index 7d64552d..a1fc23c 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_processor.__init__.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_processor.__init__.py
@@ -12,20 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""An import entry for the Task Processor module.""" +"""TensorFlow Lite Task Library's processor module. + +This module contains classes related to the pre-processing and post-processing +steps of the Task Library. +""" from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2 +from tensorflow_lite_support.python.task.processor.proto import class_pb2 from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2 from tensorflow_lite_support.python.task.processor.proto import classifications_pb2 from tensorflow_lite_support.python.task.processor.proto import detection_options_pb2 from tensorflow_lite_support.python.task.processor.proto import detections_pb2 from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2 from tensorflow_lite_support.python.task.processor.proto import embedding_pb2 +from tensorflow_lite_support.python.task.processor.proto import search_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import search_result_pb2 +from tensorflow_lite_support.python.task.processor.proto import segmentation_options_pb2 +from tensorflow_lite_support.python.task.processor.proto import segmentations_pb2 BoundingBox = bounding_box_pb2.BoundingBox +Category = class_pb2.Category ClassificationOptions = classification_options_pb2.ClassificationOptions Classifications = classifications_pb2.Classifications +ClassificationResult = classifications_pb2.ClassificationResult DetectionOptions = detection_options_pb2.DetectionOptions +Detection = detections_pb2.Detection DetectionResult = detections_pb2.DetectionResult EmbeddingOptions = embedding_options_pb2.EmbeddingOptions +FeatureVector = embedding_pb2.FeatureVector Embedding = embedding_pb2.Embedding +EmbeddingResult = 
embedding_pb2.EmbeddingResult +SearchOptions = search_options_pb2.SearchOptions +SearchResult = search_result_pb2.SearchResult +NearestNeighbor = search_result_pb2.NearestNeighbor +OutputType = segmentation_options_pb2.OutputType +SegmentationOptions = segmentation_options_pb2.SegmentationOptions +ColoredLabel = segmentations_pb2.ColoredLabel +ConfidenceMask = segmentations_pb2.ConfidenceMask +Segmentation = segmentations_pb2.Segmentation +SegmentationResult = segmentations_pb2.SegmentationResult + +# Remove unnecessary modules to avoid duplication in API docs. +del bounding_box_pb2 +del class_pb2 +del classification_options_pb2 +del classifications_pb2 +del detection_options_pb2 +del detections_pb2 +del embedding_options_pb2 +del embedding_pb2 +del segmentation_options_pb2 +del segmentations_pb2 +del search_options_pb2 +del search_result_pb2
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_text.__init__.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_text.__init__.py new file mode 100644 index 0000000..90a0810 --- /dev/null +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_text.__init__.py
@@ -0,0 +1,31 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite Task Library Text APIs. + +This module provides interface to run TensorFlow Lite natural language +processing models. +""" + +from tensorflow_lite_support.python.task.text import text_embedder +from tensorflow_lite_support.python.task.text import text_searcher + +TextEmbedder = text_embedder.TextEmbedder +TextEmbedderOptions = text_embedder.TextEmbedderOptions +TextSearcher = text_searcher.TextSearcher +TextSearcherOptions = text_searcher.TextSearcherOptions + +# Remove unnecessary modules to avoid duplication in API docs. +del text_embedder +del text_searcher
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_vision.__init__.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_vision.__init__.py index e347e120..f36fc90 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_vision.__init__.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/task_vision.__init__.py
@@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""An import entry for Task Vision Library.""" +"""TensorFlow Lite Task Library Vision APIs. + +This module provides interface to run TensorFlow Lite computer vision models. +""" from tensorflow_lite_support.python.task.vision import image_classifier from tensorflow_lite_support.python.task.vision import image_embedder +from tensorflow_lite_support.python.task.vision import image_segmenter +from tensorflow_lite_support.python.task.vision import image_searcher from tensorflow_lite_support.python.task.vision import object_detector from tensorflow_lite_support.python.task.vision.core import tensor_image @@ -25,4 +30,16 @@ ObjectDetectorOptions = object_detector.ObjectDetectorOptions ImageEmbedder = image_embedder.ImageEmbedder ImageEmbedderOptions = image_embedder.ImageEmbedderOptions +ImageSegmenter = image_segmenter.ImageSegmenter +ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions +ImageSearcher = image_searcher.ImageSearcher +ImageSearcherOptions = image_searcher.ImageSearcherOptions TensorImage = tensor_image.TensorImage + +# Remove unnecessary modules to avoid duplication in API docs. +del image_classifier +del image_embedder +del image_segmenter +del image_searcher +del object_detector +del tensor_image
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py index de41446..e1f5f89d 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The TFLite Support Library. +"""The TensorFlow Lite Support Library. Install the pip package: @@ -20,9 +20,17 @@ pip install tflite-support ``` +This package provides two major features: +* Metadata writers: add metadata to TensorFlow Lite models. +* Task Library: run TensorFlow Lite models of major machine learning tasks. + To learn more about metadata, flatbuffers and TensorFlow Lite models, check out the [metadata section](https://www.tensorflow.org/lite/convert/metadata) of the -TF Lite guide. +TensorFlow Lite guide. + +To learn more about Task Library, check out the +[documentation](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview) +on the TensorFlow Lite website. """ # In the original project structure, all python targets are accessed by paths @@ -33,7 +41,13 @@ # In pip build, this file will be renamed as tflite_support/__init__.py. import flatbuffers +import platform + from tensorflow_lite_support.metadata import metadata_schema_py_generated from tensorflow_lite_support.metadata import schema_py_generated from tensorflow_lite_support.metadata.python import metadata from tflite_support import metadata_writers + +if platform.system() != 'Windows': + # Task Library is not supported on Windows yet. + from tflite_support import task
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/zip_files.py b/third_party/tflite_support/src/tensorflow_lite_support/tools/zip_files.py index 9dc66236..d98d074 100644 --- a/third_party/tflite_support/src/tensorflow_lite_support/tools/zip_files.py +++ b/third_party/tflite_support/src/tensorflow_lite_support/tools/zip_files.py
@@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# Lint as: python3 """Creates a zip package of the files passed in.""" from __future__ import absolute_import from __future__ import division
diff --git a/third_party/tflite_support/src/third_party/libzip.BUILD b/third_party/tflite_support/src/third_party/libzip.BUILD deleted file mode 100644 index 2f75f40..0000000 --- a/third_party/tflite_support/src/third_party/libzip.BUILD +++ /dev/null
@@ -1,189 +0,0 @@ -package( - default_visibility = ["//visibility:public"], -) - -load("@org_tensorflow_lite_support//tensorflow_lite_support/tools/build_rules:expand_template.bzl", "cmake_substitutions", "expand_template") - -_CMAKE_VARIABLES = { - "INT16_T_LIBZIP": 2, - "INT32_T_LIBZIP": 4, - "INT64_T_LIBZIP": 8, - "INT8_T_LIBZIP": 1, - "INT_LIBZIP": 4, - "LIBZIP_TYPES_INCLUDE": "#include <stdint.h>", - "LONG_LIBZIP": 8, - "LONG_LONG_LIBZIP": 8, - "PACKAGE_VERSION": "1.5.1", - "PACKAGE_VERSION_MAJOR": "1", - "PACKAGE_VERSION_MICRO": "1", - "PACKAGE_VERSION_MINOR": "5", - "SHORT_LIBZIP": 2, - "SIZEOF_OFF_T": 8, - "SIZE_T_LIBZIP": 8, - "SSIZE_T_LIBZIP": 8, - "UINT16_T_LIBZIP": 2, - "UINT32_T_LIBZIP": 4, - "UINT64_T_LIBZIP": 8, - "UINT8_T_LIBZIP": 1, - "__INT16_LIBZIP": None, - "__INT32_LIBZIP": None, - "__INT64_LIBZIP": None, - "__INT8_LIBZIP": None, -} - -_CMAKE_VARIABLES.update(dict([ - ( - "ZIP_{sign}INT{size}_T".format( - size = size, - sign = sign.upper(), - ), - "{sign}int{size}_t".format( - size = size, - sign = sign.lower(), - ), - ) - for sign in ("U", "") - for size in (8, 16, 32, 64) -])) - -_SUBSTITUTIONS = { - "@PACKAGE@": "libzip", - "@VERSION@": "1.5.1", # Keep in sync with actual package! 
-} - -_DEFINES = { - "HAVE_CLONEFILE": False, - "HAVE_COMMONCRYPTO": False, - "HAVE_CRYPTO": False, - "HAVE_DIRENT_H": False, - "HAVE_FICLONERANGE": False, - "HAVE_FILENO": True, - "HAVE_FSEEK": True, - "HAVE_FSEEKO": True, - "HAVE_FTELLO": True, - "HAVE_FTS_H": True, - "HAVE_GETPROGNAME": False, - "HAVE_GNUTLS": False, - "HAVE_LIBBZ2": False, - "HAVE_MKSTEMP": True, - "HAVE_NDIR_H": False, - "HAVE_OPEN": True, - "HAVE_OPENSSL": False, - "HAVE_SETMODE": False, - "HAVE_SHARED": True, - "HAVE_SNPRINTF": True, - "HAVE_SSIZE_T_LIBZIP": True, - "HAVE_STDBOOL_H": True, - "HAVE_STRCASECMP": True, - "HAVE_STRDUP": True, - "HAVE_STRICMP": False, - "HAVE_STRINGS_H": True, - "HAVE_STRTOLL": True, - "HAVE_STRTOULL": True, - "HAVE_STRUCT_TM_TM_ZONE": False, - "HAVE_SYS_DIR_H": False, - "HAVE_SYS_NDIR_H": False, - "HAVE_UNISTD_H": True, - "HAVE__CHMOD": False, - "HAVE__CLOSE": False, - "HAVE__DUP": False, - "HAVE__FDOPEN": False, - "HAVE__FILENO": False, - "HAVE__OPEN": False, - "HAVE__SETMODE": False, - "HAVE__SNPRINTF": False, - "HAVE__STRDUP": False, - "HAVE__STRICMP": False, - "HAVE__STRTOI64": False, - "HAVE__STRTOUI64": False, - "HAVE__UMASK": False, - "HAVE__UNLINK": False, - "HAVE___PROGNAME": False, - "WORDS_BIGENDIAN": False, -} - -_DEFINES.update(dict([( - key, - value != None, -) for key, value in _CMAKE_VARIABLES.items()])) - -_SUBSTITUTIONS.update(cmake_substitutions( - defines = _DEFINES, - vars = _CMAKE_VARIABLES, -)) - -expand_template( - name = "config_h", - out = "config.h", - substitutions = _SUBSTITUTIONS, - template = "cmake-config.h.in", -) - -_VARS = { - "LIBZIP_TYPES_INCLUDE": "#include <stdint.h>", - "PACKAGE_VERSION": "1.5.1", - "PACKAGE_VERSION_MAJOR": "1", - "PACKAGE_VERSION_MICRO": "1", - "PACKAGE_VERSION_MINOR": "5", -} - -_VARS.update(dict([ - ( - "ZIP_{sign}INT{size}_T".format( - size = size, - sign = sign.upper(), - ), - "{sign}int{size}_t".format( - size = size, - sign = sign.lower(), - ), - ) - for sign in ("U", "") - for size in (8, 16, 32, 
64) -])) - -expand_template( - name = "zipconf_h", - out = "lib/zipconf.h", - substitutions = cmake_substitutions( - defines = { - "LIBZIP_VERSION": True, - "LIBZIP_VERSION_MAJOR": True, - "LIBZIP_VERSION_MICRO": True, - "LIBZIP_VERSION_MINOR": True, - "ZIP_STATIC": False, - }, - vars = _VARS, - ), - template = "cmake-zipconf.h.in", -) - -cc_library( - name = "zip", - srcs = glob( - [ - "lib/*.c", - "lib/*.h", - ], - exclude = [ - "lib/*win32*", - "lib/zip_random_uwp.c", - "lib/*crypto*", - "lib/*aes*", - "lib/*bzip2*", - ], - ) + [ - "config.h", - ], - hdrs = [ - "lib/zip.h", - "lib/zipconf.h", - ], - copts = [ - "-DHAVE_CONFIG_H", - ], - includes = ["lib"], - deps = [ - "@zlib", - ], -)
diff --git a/third_party/tflite_support/src/third_party/zlib.BUILD b/third_party/tflite_support/src/third_party/zlib.BUILD index 3a93488..6fc09c02 100644 --- a/third_party/tflite_support/src/third_party/zlib.BUILD +++ b/third_party/tflite_support/src/third_party/zlib.BUILD
@@ -42,8 +42,6 @@ name = "zlib_minizip", srcs = [ "contrib/minizip/ioapi.c", - "contrib/minizip/miniunz.c", - "contrib/minizip/minizip.c", "contrib/minizip/unzip.c", "contrib/minizip/zip.c", ],
diff --git a/third_party/tflite_support/src/third_party/zlib.patch b/third_party/tflite_support/src/third_party/zlib.patch new file mode 100644 index 0000000..6304ed8 --- /dev/null +++ b/third_party/tflite_support/src/third_party/zlib.patch
@@ -0,0 +1,62 @@ +diff -ruN a/contrib/minizip/ioapi.h b/contrib/minizip/ioapi.h +--- a/contrib/minizip/ioapi.h ++++ b/contrib/minizip/ioapi.h +@@ -21,7 +21,7 @@ + #ifndef _ZLIBIOAPI64_H + #define _ZLIBIOAPI64_H + +-#if (!defined(_WIN32)) && (!defined(WIN32)) && (!defined(__APPLE__)) ++#if (!defined(_WIN32)) && (!defined(WIN32)) && (!defined(__APPLE__)) && (!defined(__ANDROID__)) + + // Linux needs this to support file operation on files larger then 4+GB + // But might need better if/def to select just the platforms that needs them. +diff -ruN a/contrib/minizip/miniunz.c b/contrib/minizip/miniunz.c +--- a/contrib/minizip/miniunz.c ++++ b/contrib/minizip/miniunz.c +@@ -12,7 +12,7 @@ + Copyright (C) 2009-2010 Mathias Svensson ( http://result42.com ) + */ + +-#if (!defined(_WIN32)) && (!defined(WIN32)) && (!defined(__APPLE__)) ++#if (!defined(_WIN32)) && (!defined(WIN32)) && (!defined(__APPLE__)) && (!defined(__ANDROID__)) + #ifndef __USE_FILE_OFFSET64 + #define __USE_FILE_OFFSET64 + #endif +@@ -27,7 +27,7 @@ + #endif + #endif + +-#ifdef __APPLE__ ++#if defined(__APPLE__) || defined(IOAPI_NO_64) + // In darwin and perhaps other BSD variants off_t is a 64 bit value, hence no need for specific 64 bit functions + #define FOPEN_FUNC(filename, mode) fopen(filename, mode) + #define FTELLO_FUNC(stream) ftello(stream) +@@ -50,6 +50,7 @@ + # include <direct.h> + # include <io.h> + #else ++# include <sys/stat.h> + # include <unistd.h> + # include <utime.h> + #endif +diff -ruN a/contrib/minizip/minizip.c b/contrib/minizip/minizip.c +--- a/contrib/minizip/minizip.c ++++ b/contrib/minizip/minizip.c +@@ -13,7 +13,7 @@ + */ + + +-#if (!defined(_WIN32)) && (!defined(WIN32)) && (!defined(__APPLE__)) ++#if (!defined(_WIN32)) && (!defined(WIN32)) && (!defined(__APPLE__)) && (!defined(__ANDROID__)) + #ifndef __USE_FILE_OFFSET64 + #define __USE_FILE_OFFSET64 + #endif +@@ -28,7 +28,7 @@ + #endif + #endif + +-#ifdef __APPLE__ ++#if defined(__APPLE__) || defined(IOAPI_NO_64) + // In darwin 
and perhaps other BSD variants off_t is a 64 bit value, hence no need for specific 64 bit functions + #define FOPEN_FUNC(filename, mode) fopen(filename, mode) + #define FTELLO_FUNC(stream) ftello(stream)
diff --git a/tools/binary_size/generate_milestone_reports.py b/tools/binary_size/generate_milestone_reports.py index e4cd3b50..1d32f01 100755 --- a/tools/binary_size/generate_milestone_reports.py +++ b/tools/binary_size/generate_milestone_reports.py
@@ -91,6 +91,7 @@ '100.0.4896.12', '101.0.4951.20', '102.0.5005.37', + '103.0.5060.9', ]
diff --git a/tools/metrics/histograms/enums.xml b/tools/metrics/histograms/enums.xml index 774eb56..bebbf69 100644 --- a/tools/metrics/histograms/enums.xml +++ b/tools/metrics/histograms/enums.xml
@@ -40078,7 +40078,7 @@ <int value="4157" label="V8UDPSocket_RemotePort_AttributeGetter"/> <int value="4158" label="V8UDPSocket_Writable_AttributeGetter"/> <int value="4159" label="AbortSignalTimeout"/> - <int value="4160" label="ClientHintsPartitionedCookies"/> + <int value="4160" label="OBSOLETE_ClientHintsPartitionedCookies"/> <int value="4161" label="V8Document_Prerendering_AttributeGetter"/> <int value="4162" label="V8Document_Onprerenderingchange_AttributeGetter"/> <int value="4163" label="V8Document_Onprerenderingchange_AttributeSetter"/> @@ -40159,6 +40159,11 @@ <int value="4235" label="V8PaymentInstruments_Set_Method"/> <int value="4236" label="PerformanceMeasureFindExistingName"/> <int value="4237" label="FlexboxNewAbsPos"/> + <int value="4238" label="ScriptSchedulingType_Defer"/> + <int value="4239" label="ScriptSchedulingType_ParserBlocking"/> + <int value="4240" label="ScriptSchedulingType_ParserBlockingInline"/> + <int value="4241" label="ScriptSchedulingType_InOrder"/> + <int value="4242" label="ScriptSchedulingType_Async"/> </enum> <enum name="FeaturePolicyAllowlistType"> @@ -40274,7 +40279,7 @@ <int value="93" label="ClientHintUAFullVersionList"/> <int value="94" label="ClientHintUAFull"/> <int value="95" label="ClientHintUAWoW64"/> - <int value="96" label="ClientHintPartitionedCookies"/> + <int value="96" label="Deprecated: ClientHintPartitionedCookies"/> <int value="97" label="BrowsingTopics"/> <int value="98" label="BrowsingTopicsBackwardCompatible"/> <int value="99" label="ClientHintSaveData"/> @@ -57511,6 +57516,7 @@ <int value="-167744090" label="EnableHomeLauncher:enabled"/> <int value="-167420098" label="WebBluetoothNewPermissionsBackend:enabled"/> <int value="-165756594" label="enable-touch-feedback"/> + <int value="-165712979" label="AdaptiveChargingForTesting:enabled"/> <int value="-165006916" label="EnableNeuralPalmDetectionFilter:enabled"/> <int value="-164673139" label="ForceShowContinueSection:disabled"/> <int value="-164539906" @@ 
-58504,6 +58510,7 @@ <int value="494939785" label="InsecureFormSubmissionInterstitial:disabled"/> <int value="495435958" label="PageInfoAboutThisSiteMoreInfo:enabled"/> <int value="496667708" label="ArcInputOverlay:disabled"/> + <int value="497039057" label="AdaptiveChargingForTesting:disabled"/> <int value="497137719" label="OmniboxVoiceSearchAlwaysVisible:disabled"/> <int value="497150691" label="AssistEmojiEnhanced:disabled"/> <int value="500177932" label="ArcSmartTextSelection:disabled"/> @@ -59194,6 +59201,7 @@ <int value="963457392" label="ChromeHomeModernLayout:disabled"/> <int value="963671232" label="DrawOcclusion:disabled"/> <int value="964613807" label="ContextualSearchTranslationModel:disabled"/> + <int value="964995928" label="LauncherHideContinueSection:enabled"/> <int value="965037619" label="HappinessTrackingSurveysForDesktopSettingsPrivacy:disabled"/> <int value="966134219" label="CrostiniEnableDlc:disabled"/> @@ -59585,6 +59593,7 @@ <int value="1215531732" label="OmniboxUIExperiments:disabled"/> <int value="1215768255" label="AutofillCreditCardLocalCardMigration:enabled"/> <int value="1216286283" label="tint-composited-content"/> + <int value="1216363133" label="LauncherHideContinueSection:disabled"/> <int value="1216452475" label="SyncUSSAutofillProfile:disabled"/> <int value="1216488634" label="NavigationNetworkResponseQueue:enabled"/> <int value="1217907443" label="spurious-power-button-keyboard-accel"/>
diff --git a/tools/metrics/histograms/metadata/android/histograms.xml b/tools/metrics/histograms/metadata/android/histograms.xml index 5d94142..1bbd6fd 100644 --- a/tools/metrics/histograms/metadata/android/histograms.xml +++ b/tools/metrics/histograms/metadata/android/histograms.xml
@@ -772,7 +772,7 @@ </histogram> <histogram name="Android.ContactsPicker.PropertiesRequested" - units="ContactsPickerProperties" expires_after="2020-11-29"> + enum="ContactsPickerProperties" expires_after="2020-11-29"> <owner>finnur@chromium.org</owner> <owner>twellington@chromium.org</owner> <summary> @@ -3013,7 +3013,7 @@ </histogram> <histogram name="Android.SelectFileDialogContentSelected" - units="SelectFileDialogContent" expires_after="2022-10-16"> + enum="SelectFileDialogContent" expires_after="2022-10-16"> <owner>finnur@chromium.org</owner> <owner>peter@chromium.org</owner> <summary> @@ -3320,7 +3320,7 @@ variants="ThumbnailProvider_ClientType"/> </histogram> -<histogram name="Android.Toolbar.BitmapCapture" units="ToolbarCaptureType" +<histogram name="Android.Toolbar.BitmapCapture" enum="ToolbarCaptureType" expires_after="2023-05-11"> <owner>skym@chromium.org</owner> <owner>seacow@google.com</owner> @@ -3333,7 +3333,7 @@ </histogram> <histogram name="Android.TopToolbar.AllowCaptureReason" - units="TopToolbarAllowCaptureReason" expires_after="2023-05-16"> + enum="TopToolbarAllowCaptureReason" expires_after="2023-05-16"> <owner>skym@chromium.org</owner> <owner>seacow@google.com</owner> <summary> @@ -3347,7 +3347,7 @@ </histogram> <histogram name="Android.TopToolbar.BlockCaptureReason" - units="TopToolbarBlockCaptureReason" expires_after="2023-05-16"> + enum="TopToolbarBlockCaptureReason" expires_after="2023-05-16"> <owner>skym@chromium.org</owner> <owner>seacow@google.com</owner> <summary> @@ -3358,7 +3358,7 @@ </histogram> <histogram name="Android.TopToolbar.SnapshotDifference" - units="ToolbarSnapshotDifference" expires_after="2023-05-16"> + enum="ToolbarSnapshotDifference" expires_after="2023-05-16"> <owner>skym@chromium.org</owner> <owner>seacow@google.com</owner> <summary>
diff --git a/tools/metrics/histograms/metadata/apps/histograms.xml b/tools/metrics/histograms/metadata/apps/histograms.xml index 02d4647..4392f72 100644 --- a/tools/metrics/histograms/metadata/apps/histograms.xml +++ b/tools/metrics/histograms/metadata/apps/histograms.xml
@@ -2180,7 +2180,7 @@ </histogram> <histogram name="Apps.LockScreen.AppsProfile.Creation.Success" - units="BooleanSuccess" expires_after="2022-12-01"> + enum="BooleanSuccess" expires_after="2022-12-01"> <owner>glenrob@chromium.org</owner> <owner>tbuckley@chromium.org</owner> <summary>
diff --git a/tools/metrics/histograms/metadata/ash/histograms.xml b/tools/metrics/histograms/metadata/ash/histograms.xml index 3d9361d..f8ab583 100644 --- a/tools/metrics/histograms/metadata/ash/histograms.xml +++ b/tools/metrics/histograms/metadata/ash/histograms.xml
@@ -1853,7 +1853,7 @@ </summary> </histogram> -<histogram name="Ash.DeskTemplate.ReplaceTemplate" units="BooleanHit" +<histogram name="Ash.DeskTemplate.ReplaceTemplate" enum="BooleanHit" expires_after="2022-07-14"> <owner>aprilzhou@chromium.org</owner> <owner>janetmac@chromium.org</owner> @@ -4549,7 +4549,7 @@ </summary> </histogram> -<histogram name="Ash.Window.DragMaximized.Valid" units="Boolean" +<histogram name="Ash.Window.DragMaximized.Valid" enum="Boolean" expires_after="2023-03-31"> <owner>conniekxu@chromium.org</owner> <owner>xdai@chromium.org</owner>
diff --git a/tools/metrics/histograms/metadata/autofill/histograms.xml b/tools/metrics/histograms/metadata/autofill/histograms.xml index ae096afb..6a0411a 100644 --- a/tools/metrics/histograms/metadata/autofill/histograms.xml +++ b/tools/metrics/histograms/metadata/autofill/histograms.xml
@@ -1641,7 +1641,7 @@ </histogram> <histogram name="Autofill.ImageFetcher.RequestLatency" units="ms" - expires_after="2022-07-01"> + expires_after="2023-07-01"> <owner>vishwasuppoor@chromium.org</owner> <owner>siyua@chromium.org</owner> <owner>payments-autofill-team@google.com</owner> @@ -1652,7 +1652,7 @@ </histogram> <histogram name="Autofill.ImageFetcher.Result" enum="BooleanSuccess" - expires_after="2022-11-13"> + expires_after="2023-07-01"> <owner>siyua@chromium.org</owner> <owner>payments-autofill-team@google.com</owner> <summary> @@ -3575,9 +3575,9 @@ </histogram> <histogram name="Autofill.UsedCachedServerCard" units="uses" - expires_after="2022-10-30"> - <owner>jsaul@google.com</owner> + expires_after="2023-06-01"> <owner>siyua@chromium.org</owner> + <owner>jsaul@google.com</owner> <owner>payments-autofill-team@google.com</owner> <summary> Records the number of times that the cache for unmasked server cards has @@ -3588,9 +3588,9 @@ </histogram> <histogram name="Autofill.UsedCachedVirtualCard" units="uses" - expires_after="2022-06-26"> - <owner>jsaul@google.com</owner> + expires_after="2023-06-01"> <owner>siyua@chromium.org</owner> + <owner>jsaul@google.com</owner> <owner>payments-autofill-team@google.com</owner> <summary> Records the number of times that the cache for virtual cards has been @@ -3649,6 +3649,20 @@ <token key="Source" variants="Autofill.VirtualCard.RequestSource"/> </histogram> +<histogram + name="Autofill.VirtualCard.GetDetailsForEnrollment.Latency.{Source}.{Result}" + units="ms" expires_after="2022-08-07"> + <owner>siyua@chromium.org</owner> + <owner>payments-autofill-team@google.com</owner> + <summary> + Records the latency for the GetDetailsForEnrollment roundtrip call. The + timer starts when a GetDetailsForEnrollment request is sent. It is recorded + (the timer stops) when a GetDetailsForEnrollment response is received. 
+ </summary> + <token key="Source" variants="Autofill.VirtualCard.RequestSource"/> + <token key="Result" variants="Autofill.PaymentsRpcResult"/> +</histogram> + <histogram name="Autofill.VirtualCard.GetDetailsForEnrollment.Result.{Source}" enum="BooleanSuccess" expires_after="2022-08-07"> <owner>siyua@chromium.org</owner> @@ -3903,7 +3917,7 @@ </summary> </histogram> -<histogram name="Autofill.WebOTP.OneTimeCode.FromAutocomplete" units="Boolean" +<histogram name="Autofill.WebOTP.OneTimeCode.FromAutocomplete" enum="Boolean" expires_after="2022-12-12"> <owner>yigu@chromium.org</owner> <owner>battre@chromium.org</owner> @@ -3916,7 +3930,7 @@ </histogram> <histogram name="Autofill.WebOTP.PhoneNumberCollection.ParseResult" - units="Boolean" expires_after="2022-12-12"> + enum="Boolean" expires_after="2022-12-12"> <owner>yigu@chromium.org</owner> <owner>battre@chromium.org</owner> <owner>web-identity@google.com</owner>
diff --git a/tools/metrics/histograms/metadata/browser/histograms.xml b/tools/metrics/histograms/metadata/browser/histograms.xml index 2ff7ec63..e9fb33c 100644 --- a/tools/metrics/histograms/metadata/browser/histograms.xml +++ b/tools/metrics/histograms/metadata/browser/histograms.xml
@@ -162,7 +162,7 @@ </histogram> <histogram name="Browser.PaintPreview.Player.CompositorProcessStartedCorrectly" - units="BooleanSuccess" expires_after="2022-11-13"> + enum="BooleanSuccess" expires_after="2022-11-13"> <owner>ckitagawa@chromium.org</owner> <owner>fredmello@chromium.org</owner> <owner>chrome-fdt@google.com</owner> @@ -181,8 +181,8 @@ </summary> </histogram> -<histogram name="Browser.PaintPreview.Player.LinkClicked" - units="BooleanSuccess" expires_after="2022-10-16"> +<histogram name="Browser.PaintPreview.Player.LinkClicked" enum="BooleanSuccess" + expires_after="2022-10-16"> <owner>ckitagawa@chromium.org</owner> <owner>fredmello@chromium.org</owner> <owner>chrome-fdt@google.com</owner> @@ -224,7 +224,7 @@ </histogram> <histogram name="Browser.PaintPreview.TabbedPlayer.FirstPaintBeforeTabLoad" - units="Boolean" expires_after="2022-08-07"> + enum="Boolean" expires_after="2022-08-07"> <owner>ckitagawa@chromium.org</owner> <owner>fredmello@chromium.org</owner> <owner>chrome-fdt@google.com</owner> @@ -234,7 +234,7 @@ </summary> </histogram> -<histogram name="Browser.PaintPreview.TabbedPlayer.HadCapture" units="Boolean" +<histogram name="Browser.PaintPreview.TabbedPlayer.HadCapture" enum="Boolean" expires_after="2022-10-23"> <owner>ckitagawa@chromium.org</owner> <owner>fredmello@chromium.org</owner> @@ -716,7 +716,7 @@ </histogram> <histogram name="BrowserRenderProcessHost.LabeledInTaskManager" - units="BooleanLabeledRendererTask" expires_after="2022-11-06"> + enum="BooleanLabeledRendererTask" expires_after="2022-11-06"> <owner>creis@chromium.org</owner> <owner>avi@chromium.org</owner> <summary>
diff --git a/tools/metrics/histograms/metadata/chromeos/histograms.xml b/tools/metrics/histograms/metadata/chromeos/histograms.xml index 7216ca2..3228f08 100644 --- a/tools/metrics/histograms/metadata/chromeos/histograms.xml +++ b/tools/metrics/histograms/metadata/chromeos/histograms.xml
@@ -683,7 +683,7 @@ </histogram> <histogram name="ChromeOS.DiagnosticsUi.InitialScreen" - units="CrosDiagnosticsNavigationView" expires_after="2022-10-04"> + enum="CrosDiagnosticsNavigationView" expires_after="2022-10-04"> <owner>gavindodd@chromium.org</owner> <owner>zentaro@chromium.org</owner> <owner>cros-peripherals@google.com</owner>
diff --git a/tools/metrics/histograms/metadata/commerce/histograms.xml b/tools/metrics/histograms/metadata/commerce/histograms.xml index 0e61dc7..048eb04 100644 --- a/tools/metrics/histograms/metadata/commerce/histograms.xml +++ b/tools/metrics/histograms/metadata/commerce/histograms.xml
@@ -102,7 +102,7 @@ </histogram> <histogram name="Commerce.Carts.FormSubmitIsTransaction" - units="BooleanIsTransaction" expires_after="2022-11-06"> + enum="BooleanIsTransaction" expires_after="2022-11-06"> <owner>wychen@chromium.org</owner> <owner>yuezhanggg@chromium.org</owner> <owner>chrome-shopping@google.com</owner> @@ -112,7 +112,7 @@ </summary> </histogram> -<histogram name="Commerce.Carts.XHRIsAddToCart" units="BooleanIsAddToCart" +<histogram name="Commerce.Carts.XHRIsAddToCart" enum="BooleanIsAddToCart" expires_after="2022-07-31"> <owner>wychen@chromium.org</owner> <owner>yuezhanggg@chromium.org</owner>
diff --git a/tools/metrics/histograms/metadata/content/histograms.xml b/tools/metrics/histograms/metadata/content/histograms.xml index 930a7c7e..175f96d7 100644 --- a/tools/metrics/histograms/metadata/content/histograms.xml +++ b/tools/metrics/histograms/metadata/content/histograms.xml
@@ -1542,7 +1542,7 @@ </histogram> <histogram name="ContentSuggestions.{FeedType}.InfoCard.{Action}" - units="FeedInfoCardType" expires_after="2023-05-01"> + enum="FeedInfoCardType" expires_after="2023-05-01"> <owner>jianli@chromium.org</owner> <owner>harringtond@chromium.org</owner> <owner>feed@chromium.org</owner>
diff --git a/tools/metrics/histograms/metadata/content_creation/histograms.xml b/tools/metrics/histograms/metadata/content_creation/histograms.xml index d991273..54acfde 100644 --- a/tools/metrics/histograms/metadata/content_creation/histograms.xml +++ b/tools/metrics/histograms/metadata/content_creation/histograms.xml
@@ -110,7 +110,7 @@ </summary> </histogram> -<histogram name="LightweightReactions.GifGenerationCancelled" units="Boolean" +<histogram name="LightweightReactions.GifGenerationCancelled" enum="Boolean" expires_after="2022-07-31"> <owner>gujen@google.com</owner> <owner>chrome-creation@google.com</owner> @@ -476,7 +476,7 @@ </histogram> <histogram name="SharedHighlights.LinkGenerated{Requested}" - units="BooleanSuccess" expires_after="2022-08-10"> + enum="BooleanSuccess" expires_after="2022-08-10"> <owner>gayane@chromium.org</owner> <owner>chrome-shared-highlighting@google.com</owner> <summary> @@ -505,7 +505,7 @@ </histogram> <histogram name="SharedHighlights.ObtainReshareLink.Status" - units="LinkToTextReshareStatus" expires_after="2022-10-09"> + enum="LinkToTextReshareStatus" expires_after="2022-10-09"> <owner>gayane@chromium.org</owner> <owner>chrome-shared-highlighting@google.com</owner> <summary> @@ -631,7 +631,7 @@ </histogram> <histogram name="TextFragmentAnchor{TextFragmentSource}.ListItemMatch" - units="Boolean" expires_after="2022-06-30"> + enum="Boolean" expires_after="2022-06-30"> <obsolete> Removed 2022-05. </obsolete> @@ -728,7 +728,7 @@ </histogram> <histogram name="TextFragmentAnchor{TextFragmentSource}.SpansMultipleBlocks" - units="Boolean" expires_after="2022-06-30"> + enum="Boolean" expires_after="2022-06-30"> <obsolete> Removed 2022-05. </obsolete> @@ -764,7 +764,7 @@ </histogram> <histogram name="TextFragmentAnchor{TextFragmentSource}.TableCellMatch" - units="Boolean" expires_after="2022-06-30"> + enum="Boolean" expires_after="2022-06-30"> <obsolete> Removed 2022-05. </obsolete>
diff --git a/tools/metrics/histograms/metadata/cras/histograms.xml b/tools/metrics/histograms/metadata/cras/histograms.xml index a36849f..eb841f9d 100644 --- a/tools/metrics/histograms/metadata/cras/histograms.xml +++ b/tools/metrics/histograms/metadata/cras/histograms.xml
@@ -253,7 +253,7 @@ </summary> </histogram> -<histogram name="Cras.HfpWidebandSpeechSupported" units="BooleanSupported" +<histogram name="Cras.HfpWidebandSpeechSupported" enum="BooleanSupported" expires_after="2022-11-20"> <owner>hychao@chromium.org</owner> <owner>chromeos-audio@google.com</owner>
diff --git a/tools/metrics/histograms/metadata/direct/histograms.xml b/tools/metrics/histograms/metadata/direct/histograms.xml index 625c947..ce10c32 100644 --- a/tools/metrics/histograms/metadata/direct/histograms.xml +++ b/tools/metrics/histograms/metadata/direct/histograms.xml
@@ -199,7 +199,7 @@ </histogram> <histogram name="DirectWrite.Fonts.Proxy.LookupTableDiskCacheHit" - units="BooleanSuccess" expires_after="2022-10-15"> + enum="BooleanSuccess" expires_after="2022-10-15"> <owner>drott@chromium.org</owner> <owner>layout-dev@chromium.org</owner> <summary> @@ -211,7 +211,7 @@ </histogram> <histogram name="DirectWrite.Fonts.Proxy.LookupTablePersistSuccess" - units="BooleanSuccess" expires_after="2022-10-15"> + enum="BooleanSuccess" expires_after="2022-10-15"> <owner>drott@chromium.org</owner> <owner>layout-dev@chromium.org</owner> <summary>
diff --git a/tools/metrics/histograms/metadata/enterprise/histograms.xml b/tools/metrics/histograms/metadata/enterprise/histograms.xml index 468a8ab..93fac495 100644 --- a/tools/metrics/histograms/metadata/enterprise/histograms.xml +++ b/tools/metrics/histograms/metadata/enterprise/histograms.xml
@@ -617,7 +617,7 @@ </histogram> <histogram name="Enterprise.DeviceSettings.MissingPolicyMitigated" - units="BooleanSuccess" expires_after="2022-10-04"> + enum="BooleanSuccess" expires_after="2022-10-04"> <owner>rbock@google.com</owner> <owner>managed-devices@google.com</owner> <summary>
diff --git a/tools/metrics/histograms/metadata/extensions/histograms.xml b/tools/metrics/histograms/metadata/extensions/histograms.xml index 0f28b99..8ad0bc5 100644 --- a/tools/metrics/histograms/metadata/extensions/histograms.xml +++ b/tools/metrics/histograms/metadata/extensions/histograms.xml
@@ -233,7 +233,7 @@ </histogram> <histogram name="Extensions.BackgroundPageType" - units="ExtensionBackgroundPageType" expires_after="never"> + enum="ExtensionBackgroundPageType" expires_after="never"> <!-- expires-never: Used for monitoring user extension usage. --> <owner>rdevlin.cronin@chromium.org</owner> @@ -3174,8 +3174,8 @@ </summary> </histogram> -<histogram name="Extensions.SyncBlockedByDefaultWebAppMigration" - units="Boolean" expires_after="2022-11-13"> +<histogram name="Extensions.SyncBlockedByDefaultWebAppMigration" enum="Boolean" + expires_after="2022-11-13"> <owner>alancutter@chromium.org</owner> <owner>extensions-core@chromium.org</owner> <summary>
diff --git a/tools/metrics/histograms/metadata/gpu/histograms.xml b/tools/metrics/histograms/metadata/gpu/histograms.xml index 3fe63c9..ae3fcdb 100644 --- a/tools/metrics/histograms/metadata/gpu/histograms.xml +++ b/tools/metrics/histograms/metadata/gpu/histograms.xml
@@ -1828,7 +1828,7 @@ </histogram> <histogram name="Viz.FrameSinkVideoCapturer.CaptureSucceeded" - units="BooleanSuccess" expires_after="2023-03-01"> + enum="BooleanSuccess" expires_after="2023-03-01"> <owner>bialpio@chromium.org</owner> <owner>media-capture-dev@chromium.org</owner> <summary> @@ -1839,7 +1839,7 @@ </summary> </histogram> -<histogram name="Viz.FrameSinkVideoCapturer.FrameResurrected" units="Boolean" +<histogram name="Viz.FrameSinkVideoCapturer.FrameResurrected" enum="Boolean" expires_after="2023-03-01"> <owner>bialpio@chromium.org</owner> <owner>media-capture-dev@chromium.org</owner>
diff --git a/tools/metrics/histograms/metadata/history/histograms.xml b/tools/metrics/histograms/metadata/history/histograms.xml index 54a02a6..a89f5343 100644 --- a/tools/metrics/histograms/metadata/history/histograms.xml +++ b/tools/metrics/histograms/metadata/history/histograms.xml
@@ -324,8 +324,8 @@ </summary> </histogram> -<histogram name="History.ClearBrowsingData.InstalledAppExcluded" - units="Boolean" expires_after="2023-03-09"> +<histogram name="History.ClearBrowsingData.InstalledAppExcluded" enum="Boolean" + expires_after="2023-03-09"> <owner>ayui@chromium.org</owner> <owner>jarrydg@chromium.org</owner> <owner>chrome-owp-storage@google.com</owner> @@ -352,7 +352,7 @@ </histogram> <histogram name="History.ClearBrowsingData.InstalledAppsDialogShown" - units="Boolean" expires_after="2023-04-19"> + enum="Boolean" expires_after="2023-04-19"> <owner>ayui@chromium.org</owner> <owner>jarrydg@chromium.org</owner> <owner>chrome-owp-storage@google.com</owner>
diff --git a/tools/metrics/histograms/metadata/ios/histograms.xml b/tools/metrics/histograms/metadata/ios/histograms.xml index 77c6a40..fb3e56b 100644 --- a/tools/metrics/histograms/metadata/ios/histograms.xml +++ b/tools/metrics/histograms/metadata/ios/histograms.xml
@@ -728,7 +728,7 @@ </summary> </histogram> -<histogram name="IOS.IPH.DefaultSite.Presented" units="BooleanHit" +<histogram name="IOS.IPH.DefaultSite.Presented" enum="BooleanHit" expires_after="2023-02-28"> <owner>gambard@chromium.org</owner> <owner>lpromero@chromium.org</owner> @@ -1443,7 +1443,7 @@ </summary> </histogram> -<histogram name="IOS.Snapshots.ImageSize" units="KB" expires_after="2022-07-03"> +<histogram name="IOS.Snapshots.ImageSize" units="KB" expires_after="2023-07-03"> <owner>ajuma@chromium.org</owner> <owner>edchin@chromium.org</owner> <summary>
diff --git a/tools/metrics/histograms/metadata/navigation/histograms.xml b/tools/metrics/histograms/metadata/navigation/histograms.xml index a48fc9f..4ad88f3 100644 --- a/tools/metrics/histograms/metadata/navigation/histograms.xml +++ b/tools/metrics/histograms/metadata/navigation/histograms.xml
@@ -1652,7 +1652,7 @@ </summary> </histogram> -<histogram name="Prerender.PrerenderLoadComplete" units="BooleanSuccess" +<histogram name="Prerender.PrerenderLoadComplete" enum="BooleanSuccess" expires_after="2022-10-16"> <owner>gambard@chromium.org</owner> <owner>justincohen@chromium.org</owner>
diff --git a/tools/metrics/histograms/metadata/net/histograms.xml b/tools/metrics/histograms/metadata/net/histograms.xml index 95588be..4899d02 100644 --- a/tools/metrics/histograms/metadata/net/histograms.xml +++ b/tools/metrics/histograms/metadata/net/histograms.xml
@@ -1521,7 +1521,7 @@ </summary> </histogram> -<histogram name="Net.ExpectCT.HeaderPresentOnResponse" units="BooleanPresent" +<histogram name="Net.ExpectCT.HeaderPresentOnResponse" enum="BooleanPresent" expires_after="2022-11-13"> <owner>estark@chromium.org</owner> <owner>trusty-transport@chromium.org</owner> @@ -4483,7 +4483,7 @@ </summary> </histogram> -<histogram name="Net.SpdySession.ServerSupportsWebSocket" units="Boolean" +<histogram name="Net.SpdySession.ServerSupportsWebSocket" enum="Boolean" expires_after="2023-05-11"> <owner>dschinazi@chromium.org</owner> <owner>src/net/OWNERS</owner>
diff --git a/tools/metrics/histograms/metadata/network/histograms.xml b/tools/metrics/histograms/metadata/network/histograms.xml index 052f6dc..17a5657 100644 --- a/tools/metrics/histograms/metadata/network/histograms.xml +++ b/tools/metrics/histograms/metadata/network/histograms.xml
@@ -2135,7 +2135,7 @@ </histogram> <histogram name="Network.Shill.WiFi.Hidden.EverConnected" - units="BooleanConnected" expires_after="2023-04-01"> + enum="BooleanConnected" expires_after="2023-04-01"> <owner>jonmann@chromium.org</owner> <owner>tnagel@chromium.org</owner> <owner>cros-network-metrics@google.com</owner>
diff --git a/tools/metrics/histograms/metadata/obsolete_histograms.xml b/tools/metrics/histograms/metadata/obsolete_histograms.xml index a600bd9..ff460d6e 100644 --- a/tools/metrics/histograms/metadata/obsolete_histograms.xml +++ b/tools/metrics/histograms/metadata/obsolete_histograms.xml
@@ -841,7 +841,7 @@ </summary> </histogram> -<histogram name="AnchorElementMetrics.IsAdFrameElement" units="Boolean" +<histogram name="AnchorElementMetrics.IsAdFrameElement" enum="Boolean" expires_after="2020-05-21"> <obsolete> Removed 05/2020. @@ -10280,7 +10280,7 @@ </histogram> <histogram name="Compositing.SurfaceDependencyDeadline.DeadlineHit" - units="Boolean" expires_after="2018-10-17"> + enum="Boolean" expires_after="2018-10-17"> <obsolete> Removed as of 10/2018. This metric didn't end up being useful. </obsolete> @@ -12742,7 +12742,7 @@ </histogram> <histogram name="DataReductionProxy.NetworkProperties.CacheHit" - units="BooleanCacheHit" expires_after="M81"> + enum="BooleanCacheHit" expires_after="M81"> <obsolete> Obsoleted in April 2020 </obsolete> @@ -16473,7 +16473,7 @@ </histogram> <histogram name="Download.DatabaseDownloadExistsForDownloadSlice" - units="Boolean" expires_after="M77"> + enum="Boolean" expires_after="M77"> <obsolete> Removed in 07/2019. </obsolete> @@ -22711,7 +22711,7 @@ </summary> </histogram> -<histogram name="Extensions.ExtensionFrameMapCacheHit" units="Boolean" +<histogram name="Extensions.ExtensionFrameMapCacheHit" enum="Boolean" expires_after="M85"> <obsolete> Removed in M77. @@ -22724,7 +22724,7 @@ </summary> </histogram> -<histogram name="Extensions.ExtensionFrameMapLookupSuccessful" units="Boolean" +<histogram name="Extensions.ExtensionFrameMapLookupSuccessful" enum="Boolean" expires_after="M85"> <obsolete> Removed in M77. @@ -22800,7 +22800,7 @@ </summary> </histogram> -<histogram name="Extensions.ExtensionRendererStateCacheHit" units="Boolean" +<histogram name="Extensions.ExtensionRendererStateCacheHit" enum="Boolean" expires_after="2016-04-05"> <obsolete> Removed 4/2016. ExtensionRendererState was replaced with ExtensionFrameMap. 
@@ -26273,7 +26273,7 @@ </summary> </histogram> -<histogram name="GPU.ProgramCache.CompressDataSuccess" units="BooleanSuccess" +<histogram name="GPU.ProgramCache.CompressDataSuccess" enum="BooleanSuccess" expires_after="M77"> <obsolete> Not used after M77. ProgramCache not actively being tuned. @@ -26314,7 +26314,7 @@ </summary> </histogram> -<histogram name="GPU.ProgramCache.DecompressDataSuccess" units="BooleanSuccess" +<histogram name="GPU.ProgramCache.DecompressDataSuccess" enum="BooleanSuccess" expires_after="M77"> <obsolete> Not used after M77. ProgramCache not actively being tuned. @@ -50608,7 +50608,7 @@ </histogram> <histogram name="OptimizationGuide.HintCache.HasHint.AtCommit" - units="BooleanAvailable" expires_after="2020-04-30"> + enum="BooleanAvailable" expires_after="2020-04-30"> <obsolete> Obsolete as of 04/2020 since this histogram no longer makes sense with the current request flow. @@ -50622,7 +50622,7 @@ </histogram> <histogram name="OptimizationGuide.HintCache.HasHint.BeforeCommit" - units="BooleanAvailable" expires_after="2020-04-30"> + enum="BooleanAvailable" expires_after="2020-04-30"> <obsolete> Obsolete as of 04/2020 since this histogram no longer makes sense with the current request flow. @@ -50637,7 +50637,7 @@ </histogram> <histogram name="OptimizationGuide.HintCache.HostMatch.AtCommit" - units="BooleanMatched" expires_after="2020-04-30"> + enum="BooleanMatched" expires_after="2020-04-30"> <obsolete> Obsolete as of 04/2020 since this histogram no longer makes sense with the current request flow. @@ -50654,7 +50654,7 @@ </histogram> <histogram name="OptimizationGuide.HintCache.PageMatch.AtCommit" - units="BooleanMatched" expires_after="2020-04-30"> + enum="BooleanMatched" expires_after="2020-04-30"> <obsolete> Obsolete as of 04/2020 since this histogram no longer makes sense with the current request flow. 
@@ -65102,7 +65102,7 @@ </histogram> <histogram name="RendererScheduler.UserModel.GesturePredictedCorrectly" - units="GesturePredictionResult" expires_after="2017-08-30"> + enum="GesturePredictionResult" expires_after="2017-08-30"> <obsolete> Removed from code 2017-08. </obsolete> @@ -67225,7 +67225,7 @@ </summary> </histogram> -<histogram name="SafeBrowsing.ModuleBaseRelocation" units="BaseRelocationType" +<histogram name="SafeBrowsing.ModuleBaseRelocation" enum="BaseRelocationType" expires_after="M85"> <obsolete> No longer used. Removed 2020-06. @@ -81923,7 +81923,7 @@ </summary> </histogram> -<histogram name="Sync.Preferences.RemotePrefTypeMismatch" units="BooleanHit" +<histogram name="Sync.Preferences.RemotePrefTypeMismatch" enum="BooleanHit" expires_after="M85"> <obsolete> Removed 2020-06. @@ -84700,7 +84700,7 @@ </summary> </histogram> -<histogram name="UI.CompositorResizeLock.TimedOut" units="Boolean" +<histogram name="UI.CompositorResizeLock.TimedOut" enum="Boolean" expires_after="2019-02-07"> <obsolete> Removed in M65.
diff --git a/tools/metrics/histograms/metadata/oobe/histograms.xml b/tools/metrics/histograms/metadata/oobe/histograms.xml index a5d4c9b7..5cb70b6f 100644 --- a/tools/metrics/histograms/metadata/oobe/histograms.xml +++ b/tools/metrics/histograms/metadata/oobe/histograms.xml
@@ -335,7 +335,7 @@ <summary>Time spent on specific OOBE screen grouped by exit reason.</summary> </histogram> -<histogram name="OOBE.StepShownStatus" units="BooleanShown" +<histogram name="OOBE.StepShownStatus" enum="BooleanShown" expires_after="never"> <!-- expires-never: Core metric for monitoring OOBE flow regressions. -->
diff --git a/tools/metrics/histograms/metadata/others/histograms.xml b/tools/metrics/histograms/metadata/others/histograms.xml index e237c994..53fcd1c 100644 --- a/tools/metrics/histograms/metadata/others/histograms.xml +++ b/tools/metrics/histograms/metadata/others/histograms.xml
@@ -5403,7 +5403,7 @@ </summary> </histogram> -<histogram name="ExploreSites.ImageDecoded" units="Boolean" +<histogram name="ExploreSites.ImageDecoded" enum="Boolean" expires_after="2020-06-30"> <owner>freedjm@chromium.org</owner> <owner>chrome-explore-team@google.com</owner> @@ -5433,7 +5433,7 @@ </summary> </histogram> -<histogram name="ExploreSites.NTPLoadingCatalogFromNetwork" units="Boolean" +<histogram name="ExploreSites.NTPLoadingCatalogFromNetwork" enum="Boolean" expires_after="2020-02-16"> <owner>dimich@chromium.org</owner> <summary> @@ -5501,7 +5501,7 @@ </histogram> <histogram name="Favicons.LargeIconService.BlacklistedURLMismatch" - units="BooleanError" expires_after="M77"> + enum="BooleanError" expires_after="M77"> <obsolete> Removed in M88. </obsolete> @@ -7608,7 +7608,7 @@ </summary> </histogram> -<histogram name="Mac.FileMenuNativeShare" units="BooleanSuccess" +<histogram name="Mac.FileMenuNativeShare" enum="BooleanSuccess" expires_after="2022-12-31"> <owner>nasims@google.com</owner> <owner>ellyjones@chromium.org</owner> @@ -8017,7 +8017,7 @@ </summary> </histogram> -<histogram name="Mojo.InvalidUTF8String" units="BooleanValid" +<histogram name="Mojo.InvalidUTF8String" enum="BooleanValid" expires_after="2023-04-01"> <owner>rsesek@chromium.org</owner> <owner>chrome-mojo@google.com</owner> @@ -8748,7 +8748,7 @@ </summary> </histogram> -<histogram name="OSX.NativeShare" units="BooleanSuccess" +<histogram name="OSX.NativeShare" enum="BooleanSuccess" expires_after="2021-09-05"> <obsolete> Removed July 2021, superceded by Mac.FileMenuNativeShare. 
@@ -9085,7 +9085,7 @@ </summary> </histogram> -<histogram name="Pepper.Graphics3DHasShareGroup" units="BooleanShareGroup" +<histogram name="Pepper.Graphics3DHasShareGroup" enum="BooleanShareGroup" expires_after="M77"> <owner>jbauman@chromium.org</owner> <summary> @@ -9814,7 +9814,7 @@ </histogram> <histogram name="Process.Sandbox.Launch.WarningResultCode" - units="LaunchErrorCodes" expires_after="never"> + enum="LaunchErrorCodes" expires_after="never"> <!-- expires-never: metric needed for diagnosing sandbox issues. --> <owner>forshaw@chromium.org</owner> @@ -9863,7 +9863,7 @@ </histogram> <histogram name="ProxyOverriddenBubble.UserSelection" - units="ExtensionBubbleAction" expires_after="2020-12-31"> + enum="ExtensionBubbleAction" expires_after="2020-12-31"> <obsolete> Code removed 2021/06. </obsolete> @@ -12045,7 +12045,7 @@ </summary> </histogram> -<histogram name="SiteIsolatedCodeCache.JS.Hit" units="Boolean" +<histogram name="SiteIsolatedCodeCache.JS.Hit" enum="Boolean" expires_after="2022-11-10"> <owner>yhirano@chromium.org</owner> <owner>loading-dev@chromium.org</owner> @@ -12066,7 +12066,7 @@ </histogram> <histogram name="SiteIsolatedCodeCache.JS.PotentialMemoryBackedCodeCacheHit" - units="Boolean" expires_after="2022-11-10"> + enum="Boolean" expires_after="2022-11-10"> <owner>yhirano@chromium.org</owner> <owner>loading-dev@chromium.org</owner> <summary>
diff --git a/tools/metrics/histograms/metadata/page/histograms.xml b/tools/metrics/histograms/metadata/page/histograms.xml index 6d4e293..fd3041b 100644 --- a/tools/metrics/histograms/metadata/page/histograms.xml +++ b/tools/metrics/histograms/metadata/page/histograms.xml
@@ -146,7 +146,7 @@ </summary> </histogram> -<histogram name="PageLoad.Clients.Ads.AdDensity.Recorded" units="Boolean" +<histogram name="PageLoad.Clients.Ads.AdDensity.Recorded" enum="Boolean" expires_after="2021-03-28"> <owner>justinmron@chromium.org</owner> <owner>johnidel@chromium.org</owner> @@ -1482,7 +1482,7 @@ </histogram> <histogram name="PageLoad.Experimental.Memory.Core.UpdateReceived" - units="BooleanReceived" expires_after="2021-09-30"> + enum="BooleanReceived" expires_after="2021-09-30"> <owner>cammie@chromium.org</owner> <owner>jkarlin@chromium.org</owner> <owner>johnidel@chromium.org</owner>
diff --git a/tools/metrics/histograms/metadata/password/histograms.xml b/tools/metrics/histograms/metadata/password/histograms.xml index d739ef4..1dccf59 100644 --- a/tools/metrics/histograms/metadata/password/histograms.xml +++ b/tools/metrics/histograms/metadata/password/histograms.xml
@@ -2904,7 +2904,7 @@ </histogram> <histogram name="PasswordManager.TouchToFill.ObservedSuccessfulSubmission" - units="Boolean" expires_after="2022-10-30"> + enum="Boolean" expires_after="2022-10-30"> <owner>ioanap@chromium.org</owner> <owner>fhorschig@chromium.org</owner> <owner>kolos@chromium.org</owner> @@ -3445,7 +3445,7 @@ </histogram> <histogram name="PasswordProtection.RequestWithToken.{TriggerType}" - units="BooleanSent" expires_after="2022-10-06"> + enum="BooleanSent" expires_after="2022-10-06"> <owner>bhatiarohit@google.com</owner> <owner>chrome-safebrowsing-alerts@google.com</owner> <summary> @@ -3456,7 +3456,7 @@ <token key="TriggerType" variants="PasswordProtectionTriggerType"/> </histogram> -<histogram name="PasswordProtection.SampleReportSent" units="Boolean" +<histogram name="PasswordProtection.SampleReportSent" enum="Boolean" expires_after="2022-10-31"> <owner>drubery@chromium.org</owner> <owner>chrome-safebrowsing-alerts@google.com</owner>
diff --git a/tools/metrics/histograms/metadata/permissions/histograms.xml b/tools/metrics/histograms/metadata/permissions/histograms.xml index 722e10d..f6fa1a0 100644 --- a/tools/metrics/histograms/metadata/permissions/histograms.xml +++ b/tools/metrics/histograms/metadata/permissions/histograms.xml
@@ -703,7 +703,7 @@ </histogram> <histogram name="Permissions.Prompt.Notifications.EnabledAppLevel" - units="Boolean" expires_after="M110"> + enum="Boolean" expires_after="M110"> <owner>engedy@chromium.org</owner> <owner>src/components/permissions/PERMISSIONS_OWNERS</owner> <summary>
diff --git a/tools/metrics/histograms/metadata/phonehub/histograms.xml b/tools/metrics/histograms/metadata/phonehub/histograms.xml index 7805dfa4..3fdd94a 100644 --- a/tools/metrics/histograms/metadata/phonehub/histograms.xml +++ b/tools/metrics/histograms/metadata/phonehub/histograms.xml
@@ -79,7 +79,7 @@ <token key="MediaType" variants="CameraRollMediaType"/> </histogram> -<histogram name="PhoneHub.CameraRoll.Content.Present" units="BooleanPresent" +<histogram name="PhoneHub.CameraRoll.Content.Present" enum="BooleanPresent" expires_after="2023-02-01"> <owner>jasonsun@chromium.org</owner> <owner>chromeos-cross-device-eng@google.com</owner>
diff --git a/tools/metrics/histograms/metadata/safe_browsing/histograms.xml b/tools/metrics/histograms/metadata/safe_browsing/histograms.xml index d8b20d7..d89708ad 100644 --- a/tools/metrics/histograms/metadata/safe_browsing/histograms.xml +++ b/tools/metrics/histograms/metadata/safe_browsing/histograms.xml
@@ -1347,7 +1347,7 @@ <histogram name="SafeBrowsing.NavigationObserver.MissingInitiatorRenderFrameHostPortal" - units="BooleanExists" expires_after="2022-07-23"> + enum="BooleanExists" expires_after="2022-07-23"> <owner>vollick@chromium.org</owner> <owner>chrome-safebrowsing-alerts@google.com</owner> <summary>Logs the number of times we have a missing initiator RFH.</summary> @@ -1395,7 +1395,7 @@ </summary> </histogram> -<histogram name="SafeBrowsing.PageLoadToken.HasExpired" units="BooleanExpired" +<histogram name="SafeBrowsing.PageLoadToken.HasExpired" enum="BooleanExpired" expires_after="2022-11-06"> <owner>xinghuilu@chromium.org</owner> <owner>chrome-safebrowsing-alerts@google.com</owner> @@ -1407,7 +1407,7 @@ </histogram> <histogram name="SafeBrowsing.PageLoadToken.PasswordProtectionHasToken" - units="BooleanExists" expires_after="2022-10-08"> + enum="BooleanExists" expires_after="2022-10-08"> <owner>xinghuilu@chromium.org</owner> <owner>chrome-safebrowsing-alerts@google.com</owner> <summary> @@ -1417,7 +1417,7 @@ </histogram> <histogram name="SafeBrowsing.PageLoadToken.RealTimeCheckHasToken" - units="BooleanExists" expires_after="2022-10-08"> + enum="BooleanExists" expires_after="2022-10-08"> <owner>xinghuilu@chromium.org</owner> <owner>chrome-safebrowsing-alerts@google.com</owner> <summary> @@ -1925,7 +1925,7 @@ </token> </histogram> -<histogram name="SafeBrowsing.RT.SampledRequestSent" units="Boolean" +<histogram name="SafeBrowsing.RT.SampledRequestSent" enum="Boolean" expires_after="2022-11-13"> <owner>zackhan@chromium.org</owner> <owner>chrome-safebrowsing-alerts@google.com</owner>
diff --git a/tools/metrics/histograms/metadata/sb_client/histograms.xml b/tools/metrics/histograms/metadata/sb_client/histograms.xml index 6eb7cc05..8f4540c6 100644 --- a/tools/metrics/histograms/metadata/sb_client/histograms.xml +++ b/tools/metrics/histograms/metadata/sb_client/histograms.xml
@@ -114,7 +114,7 @@ </summary> </histogram> -<histogram name="SBClientDownload.DocumentAnalysisSuccess" units="Boolean" +<histogram name="SBClientDownload.DocumentAnalysisSuccess" enum="Boolean" expires_after="2022-10-16"> <owner>drubery@chromium.org</owner> <owner>chrome-safebrowsing-alerts@google.com</owner>
diff --git a/tools/metrics/histograms/metadata/sharing/histograms.xml b/tools/metrics/histograms/metadata/sharing/histograms.xml index e8a8f8d..0be1597 100644 --- a/tools/metrics/histograms/metadata/sharing/histograms.xml +++ b/tools/metrics/histograms/metadata/sharing/histograms.xml
@@ -92,7 +92,7 @@ </summary> </histogram> -<histogram name="Sharing.ClickToCallPhoneNumberValid" units="BooleanValid" +<histogram name="Sharing.ClickToCallPhoneNumberValid" enum="BooleanValid" expires_after="M98"> <owner>knollr@chromium.org</owner> <owner>peter@chromium.org</owner>
diff --git a/tools/metrics/histograms/metadata/startup/histograms.xml b/tools/metrics/histograms/metadata/startup/histograms.xml index 1249386..176739f8 100644 --- a/tools/metrics/histograms/metadata/startup/histograms.xml +++ b/tools/metrics/histograms/metadata/startup/histograms.xml
@@ -68,7 +68,7 @@ <histogram name="Startup.Android.Cold.FirstNavigationCommitOccurredPreForeground" - units="Boolean" expires_after="2022-11-06"> + enum="Boolean" expires_after="2022-11-06"> <owner>blundell@chromium.org</owner> <owner>yfriedman@chromium.org</owner> <summary> @@ -79,7 +79,7 @@ </histogram> <histogram name="Startup.Android.Cold.FirstPaintOccurredPreForeground" - units="Boolean" expires_after="2022-11-06"> + enum="Boolean" expires_after="2022-11-06"> <owner>blundell@chromium.org</owner> <owner>yfriedman@chromium.org</owner> <summary> @@ -318,7 +318,7 @@ </summary> </histogram> -<histogram name="Startup.Android.StartSurfaceShownAtStartup" units="Boolean" +<histogram name="Startup.Android.StartSurfaceShownAtStartup" enum="Boolean" expires_after="2022-10-07"> <owner>hanxi@chromium.org</owner> <owner>spdonghao@chromium.org</owner>
diff --git a/tools/metrics/histograms/metadata/subresource/histograms.xml b/tools/metrics/histograms/metadata/subresource/histograms.xml index c2651fa..459d032 100644 --- a/tools/metrics/histograms/metadata/subresource/histograms.xml +++ b/tools/metrics/histograms/metadata/subresource/histograms.xml
@@ -182,7 +182,7 @@ </histogram> <histogram name="SubresourceFilter.CnameAlias.Renderer.WasBlockedBasedOnAlias" - units="BooleanBlocked" expires_after="2022-06-12"> + enum="BooleanBlocked" expires_after="2022-06-12"> <owner>cammie@chromium.org</owner> <owner>chrome-ads-histograms@google.com</owner> <summary>
diff --git a/tools/metrics/histograms/metadata/tab/histograms.xml b/tools/metrics/histograms/metadata/tab/histograms.xml index d0c6fd2..85537858 100644 --- a/tools/metrics/histograms/metadata/tab/histograms.xml +++ b/tools/metrics/histograms/metadata/tab/histograms.xml
@@ -147,8 +147,8 @@ </summary> </histogram> -<histogram name="Discarding.HighPMFPolicy.DiscardSuccess" - units="BooleanSuccess" expires_after="2021-12-06"> +<histogram name="Discarding.HighPMFPolicy.DiscardSuccess" enum="BooleanSuccess" + expires_after="2021-12-06"> <owner>chrisha@chromium.org</owner> <owner>catan-team@chromium.org</owner> <summary> @@ -363,7 +363,7 @@ </summary> </histogram> -<histogram name="Tab.CloseAllTabsDialog.ClosedAllTabs" units="Boolean" +<histogram name="Tab.CloseAllTabsDialog.ClosedAllTabs" enum="Boolean" expires_after="2022-09-11"> <obsolete> Removed 4/2022. Replaced with a metric split by incognito and non-incognito @@ -381,7 +381,7 @@ </histogram> <histogram name="Tab.CloseAllTabsDialog.ClosedAllTabs.{CloseType}" - units="Boolean" expires_after="2022-09-11"> + enum="Boolean" expires_after="2022-09-11"> <owner>ckitagawa@chromium.org</owner> <owner>fredmello@chromium.org</owner> <summary> @@ -1456,7 +1456,7 @@ </histogram> <histogram name="TabManager.SessionOverlap.BackgroundTabOpening" - units="BooleanOverlap" expires_after="M79"> + enum="BooleanOverlap" expires_after="M79"> <owner>chrisha@chromium.org</owner> <summary> Whether background tab opening session is overlapped with other types of @@ -1469,7 +1469,7 @@ </histogram> <histogram name="TabManager.SessionOverlap.SessionRestore" - units="BooleanOverlap" expires_after="M77"> + enum="BooleanOverlap" expires_after="M77"> <owner>chrisha@chromium.org</owner> <summary> Whether session restore is overlapped with other types of session, e.g.,
diff --git a/tools/metrics/histograms/metadata/update_engine/histograms.xml b/tools/metrics/histograms/metadata/update_engine/histograms.xml index ede1a5b..d5a9961 100644 --- a/tools/metrics/histograms/metadata/update_engine/histograms.xml +++ b/tools/metrics/histograms/metadata/update_engine/histograms.xml
@@ -768,7 +768,7 @@ </summary> </histogram> -<histogram name="UpdateEngine.UpdateInvalidated" units="BooleanSuccess" +<histogram name="UpdateEngine.UpdateInvalidated" enum="BooleanSuccess" expires_after="2023-01-01"> <owner>kimjae@chromium.org</owner> <owner>chromeos-core-services@google.com</owner>
diff --git a/tools/metrics/histograms/metadata/v8/histograms.xml b/tools/metrics/histograms/metadata/v8/histograms.xml index f7682ea6..74a479fa 100644 --- a/tools/metrics/histograms/metadata/v8/histograms.xml +++ b/tools/metrics/histograms/metadata/v8/histograms.xml
@@ -1906,7 +1906,7 @@ </summary> </histogram> -<histogram name="V8.WasmMemoryProtectionKeysSupport" units="BooleanSupported" +<histogram name="V8.WasmMemoryProtectionKeysSupport" enum="BooleanSupported" expires_after="2022-11-06"> <owner>clemensb@chromium.org</owner> <owner>jkummerow@chromium.org</owner>
diff --git a/tools/metrics/histograms/metadata/web_core/histograms.xml b/tools/metrics/histograms/metadata/web_core/histograms.xml index 5d7419d..e9b8349e 100644 --- a/tools/metrics/histograms/metadata/web_core/histograms.xml +++ b/tools/metrics/histograms/metadata/web_core/histograms.xml
@@ -610,7 +610,7 @@ </histogram> <histogram name="WebCore.IndexedDB.TombstoneSweeper.DeletionWriteError" - units="LevelDBStatus" expires_after="2022-08-01"> + enum="LevelDBStatus" expires_after="2022-08-01"> <owner>dmurph@chromium.org</owner> <owner>pwnall@chromium.org</owner> <owner>storage-dev@chromium.org</owner>
diff --git a/tools/metrics/histograms/metadata/windows/histograms.xml b/tools/metrics/histograms/metadata/windows/histograms.xml index 93ebf63c..3140e92 100644 --- a/tools/metrics/histograms/metadata/windows/histograms.xml +++ b/tools/metrics/histograms/metadata/windows/histograms.xml
@@ -321,7 +321,7 @@ </summary> </histogram> -<histogram name="Windows.TouchDrag.Success" units="BooleanSuccess" +<histogram name="Windows.TouchDrag.Success" enum="BooleanSuccess" expires_after="2021-12-01"> <owner>davidbienvenu@chromium.org</owner> <owner>dfried@chromium.org</owner>
diff --git a/tools/metrics/ukm/ukm.xml b/tools/metrics/ukm/ukm.xml index ab61caa5..b4cd5412 100644 --- a/tools/metrics/ukm/ukm.xml +++ b/tools/metrics/ukm/ukm.xml
@@ -12452,6 +12452,34 @@ </metric> </event> +<event name="Network.CacheTransparency"> + <owner>nidhijaju@chromium.org</owner> + <owner>ricea@chromium.org</owner> + <summary> + Metrics for Cache Transparency which aims to allow a list of pervasive + payloads to be specified and cached in a single-keyed cache to improve + loading performance. Recorded for each navigation and subresource load. + </summary> + <metric name="FoundPervasivePayload" enum="Boolean"> + <summary> + Records whether a requested resource's URL matched with any of the URLs in + the Cache Transparency Pervasive Payloads List. + </summary> + <aggregation> + <history> + <statistics> + <enumeration/> + </statistics> + </history> + </aggregation> + </metric> + <metric name="TotalBytesFetched"> + <summary> + Records the total bytes fetched by the network. + </summary> + </metric> +</event> + <event name="NoStatePrefetch" singular="True"> <owner>tbansal@chromium.org</owner> <summary>
diff --git a/tools/perf/core/perfetto_binary_roller/binary_deps.json b/tools/perf/core/perfetto_binary_roller/binary_deps.json index 499419d1..4af7de6 100644 --- a/tools/perf/core/perfetto_binary_roller/binary_deps.json +++ b/tools/perf/core/perfetto_binary_roller/binary_deps.json
@@ -14,7 +14,7 @@ }, "mac": { "hash": "9116123446ead8a37d72c37facdfb0f8be2c3f83", - "full_remote_path": "chromium-telemetry/perfetto_binaries/trace_processor_shell/mac/d1cb81f2aa43df0d60a387e49c7ca570b685ca7f/trace_processor_shell" + "full_remote_path": "chromium-telemetry/perfetto_binaries/trace_processor_shell/mac/5e0d3dbcc00516ba502fc5f9631cfd2136664489/trace_processor_shell" }, "mac_arm64": { "hash": "e1ad4861384b06d911a65f035317914b8cc975c6", @@ -22,7 +22,7 @@ }, "linux": { "hash": "2311f9c5c3b835a9c50908ce9274b711b9bf898a", - "full_remote_path": "chromium-telemetry/perfetto_binaries/trace_processor_shell/linux/d1cb81f2aa43df0d60a387e49c7ca570b685ca7f/trace_processor_shell" + "full_remote_path": "chromium-telemetry/perfetto_binaries/trace_processor_shell/linux/5e0d3dbcc00516ba502fc5f9631cfd2136664489/trace_processor_shell" } }, "power_profile.sql": {
diff --git a/ui/base/ime/dummy_text_input_client.cc b/ui/base/ime/dummy_text_input_client.cc index 3a550f8..063ab67 100644 --- a/ui/base/ime/dummy_text_input_client.cc +++ b/ui/base/ime/dummy_text_input_client.cc
@@ -164,7 +164,7 @@ } #endif -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) gfx::Range DummyTextInputClient::GetAutocorrectRange() const { return autocorrect_range_; } @@ -179,7 +179,7 @@ } absl::optional<GrammarFragment> -DummyTextInputClient::GetGrammarFragmentAtCursor() { +DummyTextInputClient::GetGrammarFragmentAtCursor() const { for (const auto& fragment : grammar_fragments_) { if (fragment.range.Contains(cursor_range_)) { return fragment;
diff --git a/ui/base/ime/dummy_text_input_client.h b/ui/base/ime/dummy_text_input_client.h index 5a96ca9e..0b30b29 100644 --- a/ui/base/ime/dummy_text_input_client.h +++ b/ui/base/ime/dummy_text_input_client.h
@@ -69,11 +69,11 @@ const std::vector<ui::ImeTextSpan>& ui_ime_text_spans) override; #endif -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) gfx::Range GetAutocorrectRange() const override; gfx::Rect GetAutocorrectCharacterBounds() const override; bool SetAutocorrectRange(const gfx::Range& range) override; - absl::optional<GrammarFragment> GetGrammarFragmentAtCursor() override; + absl::optional<GrammarFragment> GetGrammarFragmentAtCursor() const override; bool ClearGrammarFragments(const gfx::Range& range) override; bool AddGrammarFragments( const std::vector<GrammarFragment>& fragments) override;
diff --git a/ui/base/ime/fake_text_input_client.cc b/ui/base/ime/fake_text_input_client.cc index 69095ce..7b3c225 100644 --- a/ui/base/ime/fake_text_input_client.cc +++ b/ui/base/ime/fake_text_input_client.cc
@@ -174,7 +174,7 @@ } #endif -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) gfx::Range FakeTextInputClient::GetAutocorrectRange() const { return autocorrect_range_; }
diff --git a/ui/base/ime/fake_text_input_client.h b/ui/base/ime/fake_text_input_client.h index 8a57015..5a7feae 100644 --- a/ui/base/ime/fake_text_input_client.h +++ b/ui/base/ime/fake_text_input_client.h
@@ -74,7 +74,7 @@ const gfx::Range& range, const std::vector<ui::ImeTextSpan>& ui_ime_text_spans) override; #endif -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) gfx::Range GetAutocorrectRange() const override; gfx::Rect GetAutocorrectCharacterBounds() const override; bool SetAutocorrectRange(const gfx::Range& range) override;
diff --git a/ui/base/ime/text_input_client.cc b/ui/base/ime/text_input_client.cc index 26eb3922..b5275dc 100644 --- a/ui/base/ime/text_input_client.cc +++ b/ui/base/ime/text_input_client.cc
@@ -9,8 +9,9 @@ TextInputClient::~TextInputClient() { } -#if BUILDFLAG(IS_CHROMEOS_ASH) -absl::optional<GrammarFragment> TextInputClient::GetGrammarFragmentAtCursor() { +#if BUILDFLAG(IS_CHROMEOS) +absl::optional<GrammarFragment> TextInputClient::GetGrammarFragmentAtCursor() + const { return absl::nullopt; }
diff --git a/ui/base/ime/text_input_client.h b/ui/base/ime/text_input_client.h index 07209fc..7c1d74e 100644 --- a/ui/base/ime/text_input_client.h +++ b/ui/base/ime/text_input_client.h
@@ -55,7 +55,7 @@ FOCUS_REASON_OTHER, }; -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) enum SubClass { kRenderWidgetHostViewAura = 0, kArcImeService = 1, @@ -248,7 +248,7 @@ const std::vector<ui::ImeTextSpan>& ui_ime_text_spans) = 0; #endif -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) // Return the start and end index of the autocorrect range. If non-existent, // return an empty Range. virtual gfx::Range GetAutocorrectRange() const = 0; @@ -268,7 +268,7 @@ // Returns the grammar fragment which contains the current cursor. If // non-existent, returns nullopt. - virtual absl::optional<GrammarFragment> GetGrammarFragmentAtCursor(); + virtual absl::optional<GrammarFragment> GetGrammarFragmentAtCursor() const; // Clears all the grammar fragments in |range|, returns whether the operation // is successful. Should return true if the there is no fragment in the range.
diff --git a/ui/ozone/platform/wayland/host/wayland_toplevel_window.cc b/ui/ozone/platform/wayland/host/wayland_toplevel_window.cc index d8c7692..0de6e43 100644 --- a/ui/ozone/platform/wayland/host/wayland_toplevel_window.cc +++ b/ui/ozone/platform/wayland/host/wayland_toplevel_window.cc
@@ -408,6 +408,7 @@ gfx::Size size_in_dip = restored_size_dip().IsEmpty() ? GetBoundsInDIP().size() : restored_size_dip(); + bounds_dip.set_origin(gfx::Point(x, y)); bounds_dip.set_size(size_in_dip); } @@ -439,7 +440,13 @@ } void WaylandToplevelWindow::SetOrigin(const gfx::Point& origin) { - WaylandWindow::SetBoundsInDIP(gfx::Rect(origin, GetBoundsInDIP().size())); + // TODO(crbug.com/1306688): Using UpdateBoundsInDIP changes the size of the + // window due to the rounding. Change this to use SetBoundsInDIP when + // `bounds_px_` becomes `bounds_dip_`. + gfx::Point origin_px = + gfx::ScaleToFlooredPoint(origin, window_scale(), window_scale()); + WaylandWindow::SetBoundsInPixels( + gfx::Rect(origin_px, GetBoundsInPixels().size())); } void WaylandToplevelWindow::HandleSurfaceConfigure(uint32_t serial) {
diff --git a/ui/ozone/platform/wayland/host/wayland_window.cc b/ui/ozone/platform/wayland/host/wayland_window.cc index 7bbd6ae..acbd19e1 100644 --- a/ui/ozone/platform/wayland/host/wayland_window.cc +++ b/ui/ozone/platform/wayland/host/wayland_window.cc
@@ -37,7 +37,7 @@ #include "ui/ozone/platform/wayland/host/wayland_event_source.h" #include "ui/ozone/platform/wayland/host/wayland_frame_manager.h" #include "ui/ozone/platform/wayland/host/wayland_output_manager.h" -#include "ui/ozone/platform/wayland/host/wayland_pointer.h" +#include "ui/ozone/platform/wayland/host/wayland_screen.h" #include "ui/ozone/platform/wayland/host/wayland_subsurface.h" #include "ui/ozone/platform/wayland/host/wayland_surface.h" #include "ui/ozone/platform/wayland/host/wayland_zcr_cursor_shapes.h" @@ -134,7 +134,7 @@ // We need to keep DIP size of the window the same whenever the scale changes. if (update_bounds) - SetBoundsDip(gfx::ScaleToEnclosedRect(bounds_px_, 1.0 / old_scale)); + UpdateBoundsInDIP(gfx::ScaleToEnclosedRect(bounds_px_, 1.0 / old_scale)); // Propagate update to the child windows if (child_window_) @@ -588,7 +588,7 @@ std::move(drag_loop_quit_closure_).Run(); } -void WaylandWindow::SetBoundsDip(const gfx::Rect& bounds_dip) { +void WaylandWindow::UpdateBoundsInDIP(const gfx::Rect& bounds_dip) { // This method is used to update the content size by calling WindowWindow's // SetBounds, instead of WaylandToplevelWindow's override, which sends a // request to the compositor. @@ -651,6 +651,8 @@ } void WaylandWindow::OnEnteredOutput() { + delegate()->OnMovedToAnotherDisplay(); + // Wayland does weird things for menus so instead of tracking outputs that // we entered or left, we take that from the parent window and ignore this // event. @@ -666,8 +668,8 @@ // event. if (AsWaylandPopup()) return; - - UpdateWindowScale(true); + // Do not update the window scale where. It'll be updated when entring a new + // output. } void WaylandWindow::UpdateCursorPositionFromEvent(const Event* orig_event) { @@ -695,10 +697,10 @@ auto* toplevel_window = GetRootParentWindow(); if (toplevel_window != this) { event = Event::Clone(*orig_event); - // TODO(crbug.com/1306688): This should use DIP. 
ConvertEventLocationToTargetWindowLocation( - toplevel_window->GetBoundsInPixels().origin(), - GetBoundsInPixels().origin(), event->AsLocatedEvent()); + toplevel_window->GetBoundsInDIP().origin(), GetBoundsInDIP().origin(), + event->AsLocatedEvent()); + located_event = event->AsLocatedEvent(); } @@ -1064,7 +1066,7 @@ DCHECK(!pending_configures_.empty()); for (auto& configure : pending_configures_) configure.set = true; - SetBoundsDip(pending_configures_.back().bounds_dip); + UpdateBoundsInDIP(pending_configures_.back().bounds_dip); } bool WaylandWindow::HasPendingConfigures() const {
diff --git a/ui/ozone/platform/wayland/host/wayland_window.h b/ui/ozone/platform/wayland/host/wayland_window.h index f98fb97..4c76e61 100644 --- a/ui/ozone/platform/wayland/host/wayland_window.h +++ b/ui/ozone/platform/wayland/host/wayland_window.h
@@ -325,8 +325,9 @@ const WaylandConnection* connection() const { return connection_; } PlatformWindowDelegate* delegate() { return delegate_; } - // [Deprecatd] Sets bounds in dip. This will be replaced with SetBoundsInDIP. - void SetBoundsDip(const gfx::Rect& bounds_dip); + // Update the bounds of the window in DIP. Unlike SetBoundInDIP, it will not + // send a request to the compositor even if the screen coordinate is enabled. + void UpdateBoundsInDIP(const gfx::Rect& bounds_dip); void set_ui_scale(float ui_scale) { ui_scale_ = ui_scale; }
diff --git a/ui/ozone/platform/wayland/host/wayland_window_drag_controller.cc b/ui/ozone/platform/wayland/host/wayland_window_drag_controller.cc index 5ca82aa..0c30d23 100644 --- a/ui/ozone/platform/wayland/host/wayland_window_drag_controller.cc +++ b/ui/ozone/platform/wayland/host/wayland_window_drag_controller.cc
@@ -21,6 +21,8 @@ #include "base/run_loop.h" #include "base/task/current_thread.h" #include "ui/base/dragdrop/drag_drop_types.h" +#include "ui/display/display.h" +#include "ui/display/screen.h" #include "ui/events/event.h" #include "ui/events/event_constants.h" #include "ui/events/platform/platform_event_dispatcher.h" @@ -38,6 +40,8 @@ #include "ui/ozone/platform/wayland/host/wayland_data_device_manager.h" #include "ui/ozone/platform/wayland/host/wayland_data_offer.h" #include "ui/ozone/platform/wayland/host/wayland_data_source.h" +#include "ui/ozone/platform/wayland/host/wayland_output_manager.h" +#include "ui/ozone/platform/wayland/host/wayland_screen.h" #include "ui/ozone/platform/wayland/host/wayland_serial_tracker.h" #include "ui/ozone/platform/wayland/host/wayland_surface.h" #include "ui/ozone/platform/wayland/host/wayland_window.h" @@ -420,8 +424,7 @@ return; DCHECK(window); - // TODO(crbug.com/1306688): This should use DIP. - auto origin = window->GetBoundsInPixels().origin(); + auto origin = window->GetBoundsInDIP().origin(); gfx::Vector2d offset = gfx::ToFlooredPoint(pointer_location_) - origin; DVLOG(1) << "Toplevel window created (detached)." << " widget=" << window->GetWidget() @@ -452,10 +455,13 @@ return; // Update current cursor position relative to the event source - // (pointer_grab_owner_) so it can be retrieved later on through + // (focused window) so it can be retrieved later on through // |Screen::GetCursorScreenPoint| API. - if (pointer_grab_owner_) - pointer_grab_owner_->UpdateCursorPositionFromEvent(event); + auto* pointer_focused_window = connection_->wayland_window_manager() + ->GetCurrentPointerOrTouchFocusedWindow(); + + if (pointer_focused_window) + pointer_focused_window->UpdateCursorPositionFromEvent(event); // Notify listeners about window bounds change (i.e: re-positioning) event. 
// To do so, set the new bounds as per the motion event location and the drag @@ -463,10 +469,9 @@ // surface has no visual effect in ozone/wayland backend. Actual window // re-positioning during dragging session is done through the drag icon. if (dragged_window_) { - // TODO(crbug.com/1306688): This should use DIP. gfx::Point new_location = event->location() - drag_offset_; - gfx::Size size = dragged_window_->GetBoundsInPixels().size(); - dragged_window_->SetBoundsInPixels({new_location, size}); + gfx::Size size = dragged_window_->GetBoundsInDIP().size(); + dragged_window_->SetBoundsInDIP({new_location, size}); } should_process_drag_event_ = false;
diff --git a/ui/ozone/platform/wayland/host/wayland_window_drag_controller_unittest.cc b/ui/ozone/platform/wayland/host/wayland_window_drag_controller_unittest.cc index f06fb652..630bf22 100644 --- a/ui/ozone/platform/wayland/host/wayland_window_drag_controller_unittest.cc +++ b/ui/ozone/platform/wayland/host/wayland_window_drag_controller_unittest.cc
@@ -1109,6 +1109,7 @@ WmMoveLoopHandler* move_loop_handler, bool in_pixel_coordinates) { for (auto* output : *outputs) { + SCOPED_TRACE(base::StringPrintf("Output Scale=%d", output->GetScale())); gfx::Point p0{10, 10}; // Compute the expected point first as drag operation will move the // window.
diff --git a/ui/ozone/platform/wayland/host/wayland_window_unittest.cc b/ui/ozone/platform/wayland/host/wayland_window_unittest.cc index e7197ad..088e9f5c 100644 --- a/ui/ozone/platform/wayland/host/wayland_window_unittest.cc +++ b/ui/ozone/platform/wayland/host/wayland_window_unittest.cc
@@ -3132,6 +3132,16 @@ Sync(); } +TEST_P(WaylandWindowTest, InitialBounds) { + testing::NiceMock<MockWaylandPlatformWindowDelegate> delegate_2; + auto toplevel = CreateWaylandWindowWithParams( + PlatformWindowType::kWindow, 0, gfx::Rect(10, 10, 200, 200), &delegate_2); + toplevel->HandleAuraToplevelConfigure(20, 20, 0, 0, false, false, true); + toplevel->HandleSurfaceConfigure(2); + static_cast<WaylandToplevelWindow*>(toplevel.get())->ApplyPendingBounds(); + EXPECT_EQ(gfx::Rect(20, 20, 200, 200), toplevel->GetBoundsInDIP()); +} + namespace { class WaylandSubsurfaceTest : public WaylandWindowTest {
diff --git a/ui/platform_window/platform_window_delegate.cc b/ui/platform_window/platform_window_delegate.cc index ea10ab5..786619c7 100644 --- a/ui/platform_window/platform_window_delegate.cc +++ b/ui/platform_window/platform_window_delegate.cc
@@ -55,6 +55,8 @@ return rect_in_dip; } +void PlatformWindowDelegate::OnMovedToAnotherDisplay() {} + gfx::Rect PlatformWindowDelegate::ConvertRectToDIP( const gfx::Rect& rect_in_pixels) const { return rect_in_pixels;
diff --git a/ui/platform_window/platform_window_delegate.h b/ui/platform_window/platform_window_delegate.h index 1bae07b..e1f168b 100644 --- a/ui/platform_window/platform_window_delegate.h +++ b/ui/platform_window/platform_window_delegate.h
@@ -140,6 +140,9 @@ // Enables or disables frame rate throttling. virtual void SetFrameRateThrottleEnabled(bool enabled); + // Called when the platform window is moved to another display. + virtual void OnMovedToAnotherDisplay(); + // Convert gfx::Rect in pixels to DIP in screen, and vice versa. virtual gfx::Rect ConvertRectToPixels(const gfx::Rect& rect_in_dp) const; virtual gfx::Rect ConvertRectToDIP(const gfx::Rect& rect_in_pixells) const;
diff --git a/ui/views/controls/menu/menu_scroll_view_container.cc b/ui/views/controls/menu/menu_scroll_view_container.cc index a3354482..1ffec70 100644 --- a/ui/views/controls/menu/menu_scroll_view_container.cc +++ b/ui/views/controls/menu/menu_scroll_view_container.cc
@@ -456,8 +456,7 @@ if (ash::features::IsDarkLightModeEnabled()) { background_view_->SetBorder(std::make_unique<HighlightBorder>( corner_radius_, HighlightBorder::Type::kHighlightBorder1, - /*use_light_colors=*/false, - views::HighlightBorder::InsetsType::kHalfInsets)); + /*use_light_colors=*/false)); } #endif } else {
diff --git a/ui/views/controls/prefix_selector.cc b/ui/views/controls/prefix_selector.cc index 517f429..67c49f3 100644 --- a/ui/views/controls/prefix_selector.cc +++ b/ui/views/controls/prefix_selector.cc
@@ -181,7 +181,7 @@ } #endif -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) gfx::Range PrefixSelector::GetAutocorrectRange() const { NOTIMPLEMENTED_LOG_ONCE(); return gfx::Range();
diff --git a/ui/views/controls/prefix_selector.h b/ui/views/controls/prefix_selector.h index a444ecb..f132b73ea 100644 --- a/ui/views/controls/prefix_selector.h +++ b/ui/views/controls/prefix_selector.h
@@ -87,7 +87,7 @@ const std::vector<ui::ImeTextSpan>& ui_ime_text_spans) override; #endif -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) gfx::Range GetAutocorrectRange() const override; gfx::Rect GetAutocorrectCharacterBounds() const override; bool SetAutocorrectRange(const gfx::Range& range) override;
diff --git a/ui/views/controls/textfield/textfield.cc b/ui/views/controls/textfield/textfield.cc index 985f989..4699148d 100644 --- a/ui/views/controls/textfield/textfield.cc +++ b/ui/views/controls/textfield/textfield.cc
@@ -1743,7 +1743,7 @@ } #endif -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) gfx::Range Textfield::GetAutocorrectRange() const { return model_->autocorrect_range(); }
diff --git a/ui/views/controls/textfield/textfield.h b/ui/views/controls/textfield/textfield.h index fdc2db0..ea29dbdf 100644 --- a/ui/views/controls/textfield/textfield.h +++ b/ui/views/controls/textfield/textfield.h
@@ -449,7 +449,7 @@ const std::vector<ui::ImeTextSpan>& ui_ime_text_spans) override; #endif -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) gfx::Range GetAutocorrectRange() const override; gfx::Rect GetAutocorrectCharacterBounds() const override; bool SetAutocorrectRange(const gfx::Range& range) override;
diff --git a/ui/views/controls/textfield/textfield_model.cc b/ui/views/controls/textfield/textfield_model.cc index e6a8ea58..b3202c3 100644 --- a/ui/views/controls/textfield/textfield_model.cc +++ b/ui/views/controls/textfield/textfield_model.cc
@@ -762,7 +762,7 @@ } } -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) bool TextfieldModel::SetAutocorrectRange(const gfx::Range& range) { // TODO(crbug.com/1108170): Add an underline to |range|. if (range.GetMax() > render_text()->text().length()) {
diff --git a/ui/views/controls/textfield/textfield_model.h b/ui/views/controls/textfield/textfield_model.h index 6130f94..01452147 100644 --- a/ui/views/controls/textfield/textfield_model.h +++ b/ui/views/controls/textfield/textfield_model.h
@@ -239,7 +239,7 @@ // composition text. void SetCompositionText(const ui::CompositionText& composition); -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) // Return the text range corresponding to the autocorrected text. const gfx::Range& autocorrect_range() const { return autocorrect_range_; } @@ -340,7 +340,7 @@ gfx::Range composition_range_; -#if BUILDFLAG(IS_CHROMEOS_ASH) +#if BUILDFLAG(IS_CHROMEOS) gfx::Range autocorrect_range_; #endif
diff --git a/ui/views/widget/desktop_aura/desktop_native_widget_aura.cc b/ui/views/widget/desktop_aura/desktop_native_widget_aura.cc index 12fa05a..85c878b 100644 --- a/ui/views/widget/desktop_aura/desktop_native_widget_aura.cc +++ b/ui/views/widget/desktop_aura/desktop_native_widget_aura.cc
@@ -537,7 +537,6 @@ void DesktopNativeWidgetAura::InitNativeWidget(Widget::InitParams params) { ownership_ = params.ownership; widget_type_ = params.type; - headless_mode_ = params.headless_mode; name_ = params.name; content_window_->AcquireAllPropertiesFrom(
diff --git a/ui/views/widget/desktop_aura/desktop_native_widget_aura.h b/ui/views/widget/desktop_aura/desktop_native_widget_aura.h index 5f578ef..f2560e5 100644 --- a/ui/views/widget/desktop_aura/desktop_native_widget_aura.h +++ b/ui/views/widget/desktop_aura/desktop_native_widget_aura.h
@@ -113,9 +113,6 @@ // DesktopWindowTreeHost's transparency. void UpdateWindowTransparency(); - // Returns true if the desktop window was created in headless mode. - bool IsHeadlessMode() const { return headless_mode_; } - protected: // internal::NativeWidgetPrivate: void InitNativeWidget(Widget::InitParams params) override; @@ -347,9 +344,6 @@ // See class documentation for Widget in widget.h for a note about type. Widget::InitParams::Type widget_type_; - // Set if the desktop window was created in headless mode. - bool headless_mode_ = false; - // See DesktopWindowTreeHost::ShouldUseDesktopNativeCursorManager(). bool use_desktop_native_cursor_manager_ = false;
diff --git a/ui/views/widget/desktop_aura/desktop_window_tree_host_platform.cc b/ui/views/widget/desktop_aura/desktop_window_tree_host_platform.cc index d207fe7..c8207771 100644 --- a/ui/views/widget/desktop_aura/desktop_window_tree_host_platform.cc +++ b/ui/views/widget/desktop_aura/desktop_window_tree_host_platform.cc
@@ -891,6 +891,10 @@ return window_anchor; } +void DesktopWindowTreeHostPlatform::OnMovedToAnotherDisplay() { + WindowTreeHost::OnHostResizedInPixels(GetBoundsInPixels().size()); +} + gfx::Rect DesktopWindowTreeHostPlatform::ConvertRectToPixels( const gfx::Rect& rect_in_dip) const { return ToPixelRect(rect_in_dip);
diff --git a/ui/views/widget/desktop_aura/desktop_window_tree_host_platform.h b/ui/views/widget/desktop_aura/desktop_window_tree_host_platform.h index 377fd7a0..d8fdd0f8 100644 --- a/ui/views/widget/desktop_aura/desktop_window_tree_host_platform.h +++ b/ui/views/widget/desktop_aura/desktop_window_tree_host_platform.h
@@ -152,6 +152,7 @@ absl::optional<ui::MenuType> GetMenuType() override; absl::optional<ui::OwnedWindowAnchor> GetOwnedWindowAnchorAndRectInPx() override; + void OnMovedToAnotherDisplay() override; gfx::Rect ConvertRectToPixels(const gfx::Rect& rect_in_dip) const override; gfx::Rect ConvertRectToDIP(const gfx::Rect& rect_in_pixels) const override; gfx::PointF ConvertScreenPointToLocalDIP(
diff --git a/ui/views/widget/desktop_aura/desktop_window_tree_host_win.cc b/ui/views/widget/desktop_aura/desktop_window_tree_host_win.cc index 7ee93666..d7260950 100644 --- a/ui/views/widget/desktop_aura/desktop_window_tree_host_win.cc +++ b/ui/views/widget/desktop_aura/desktop_window_tree_host_win.cc
@@ -191,7 +191,7 @@ // We don't have an HWND yet, so scale relative to the nearest screen. gfx::Rect pixel_bounds = display::win::ScreenWin::DIPToScreenRect(nullptr, params.bounds); - message_handler_->Init(parent_hwnd, pixel_bounds); + message_handler_->Init(parent_hwnd, pixel_bounds, params.headless_mode); CreateCompositor(params.force_software_compositing); OnAcceleratedWidgetAvailable(); InitHost(); @@ -802,10 +802,6 @@ return native_widget_delegate_->IsModal(); } -bool DesktopWindowTreeHostWin::IsHeadless() const { - return desktop_native_widget_aura_->IsHeadlessMode(); -} - int DesktopWindowTreeHostWin::GetInitialShowState() const { return CanActivate() ? SW_SHOWNORMAL : SW_SHOWNOACTIVATE; }
diff --git a/ui/views/widget/desktop_aura/desktop_window_tree_host_win.h b/ui/views/widget/desktop_aura/desktop_window_tree_host_win.h index 7abd8d3..0aae49e 100644 --- a/ui/views/widget/desktop_aura/desktop_window_tree_host_win.h +++ b/ui/views/widget/desktop_aura/desktop_window_tree_host_win.h
@@ -206,7 +206,6 @@ bool WantsMouseEventsWhenInactive() const override; bool WidgetSizeIsClientSize() const override; bool IsModal() const override; - bool IsHeadless() const override; int GetInitialShowState() const override; int GetNonClientComponent(const gfx::Point& point) const override; void GetWindowMask(const gfx::Size& size, SkPath* path) override;
diff --git a/ui/views/win/hwnd_message_handler.cc b/ui/views/win/hwnd_message_handler.cc index 17e7850..2d21f2e 100644 --- a/ui/views/win/hwnd_message_handler.cc +++ b/ui/views/win/hwnd_message_handler.cc
@@ -434,12 +434,19 @@ ClearUserData(); } -void HWNDMessageHandler::Init(HWND parent, const gfx::Rect& bounds) { +void HWNDMessageHandler::Init(HWND parent, + const gfx::Rect& bounds, + bool headless_mode) { TRACE_EVENT0("views", "HWNDMessageHandler::Init"); GetMonitorAndRects(bounds.ToRECT(), &last_monitor_, &last_monitor_rect_, &last_work_area_); initial_bounds_valid_ = !bounds.IsEmpty(); + + // Provide the headless mode window state container. + if (headless_mode) + headless_mode_window_ = absl::make_optional<HeadlessModeWindow>(); + // Create the window. WindowImpl::Init(parent, bounds); @@ -660,7 +667,7 @@ // showing it just maintain a local flag to track the expected headless // window visibility state. if (IsHeadless()) { - headless_window_visibility_state_ = true; + headless_mode_window_->visibility_state = true; return; } @@ -738,7 +745,7 @@ // hiding it just maintain a local flag to track the expected headless // window visibility state. if (IsHeadless()) { - headless_window_visibility_state_ = false; + headless_mode_window_->visibility_state = false; return; } @@ -754,15 +761,30 @@ } void HWNDMessageHandler::Maximize() { + if (IsHeadless()) { + headless_mode_window_->minmax_state = HeadlessModeWindow::kMaximized; + return; + } + ExecuteSystemMenuCommand(SC_MAXIMIZE); } void HWNDMessageHandler::Minimize() { + if (IsHeadless()) { + headless_mode_window_->minmax_state = HeadlessModeWindow::kMinimized; + return; + } + ExecuteSystemMenuCommand(SC_MINIMIZE); delegate_->HandleNativeBlur(nullptr); } void HWNDMessageHandler::Restore() { + if (IsHeadless()) { + headless_mode_window_->minmax_state = HeadlessModeWindow::kNormal; + return; + } + ExecuteSystemMenuCommand(SC_RESTORE); } @@ -797,7 +819,7 @@ // In headless mode the platform window is always hidden, so instead of // returning the actual window visibility state return the expected visibility // state maintained by Show/Hide() calls. - return IsHeadless() ? 
headless_window_visibility_state_ + return IsHeadless() ? headless_mode_window_->visibility_state : !!::IsWindowVisible(hwnd()); } @@ -806,17 +828,20 @@ } bool HWNDMessageHandler::IsMinimized() const { - return !!::IsIconic(hwnd()); + return IsHeadless() ? headless_mode_window_->IsMinimized() + : !!::IsIconic(hwnd()); } bool HWNDMessageHandler::IsMaximized() const { - return !!::IsZoomed(hwnd()) && !IsFullscreen(); + return (IsHeadless() ? headless_mode_window_->IsMaximized() + : !!::IsZoomed(hwnd())) && + !IsFullscreen(); } bool HWNDMessageHandler::IsFullscreen() const { // In headless mode report the requested window state instead of the actual // one. - return IsHeadless() ? headless_window_fullscreen_state_ + return IsHeadless() ? headless_mode_window_->fullscreen_state : fullscreen_handler_->fullscreen(); } @@ -825,7 +850,7 @@ } bool HWNDMessageHandler::IsHeadless() const { - return delegate_->IsHeadless(); + return headless_mode_window_.has_value(); } bool HWNDMessageHandler::RunMoveLoop(const gfx::Vector2d& drag_offset, @@ -943,7 +968,7 @@ // Avoid setting fullscreen mode when in headless mode, but keep track // of the requested state for IsFullscreen() to report. if (IsHeadless()) { - headless_window_fullscreen_state_ = fullscreen; + headless_mode_window_->fullscreen_state = fullscreen; return; }
diff --git a/ui/views/win/hwnd_message_handler.h b/ui/views/win/hwnd_message_handler.h index a46da3e..73337f6b 100644 --- a/ui/views/win/hwnd_message_handler.h +++ b/ui/views/win/hwnd_message_handler.h
@@ -21,6 +21,7 @@ #include "base/scoped_observation.h" #include "base/win/scoped_gdi_object.h" #include "base/win/win_util.h" +#include "third_party/abseil-cpp/absl/types/optional.h" #include "ui/accessibility/platform/ax_fragment_root_delegate_win.h" #include "ui/base/ime/input_method.h" #include "ui/base/ime/input_method_observer.h" @@ -94,7 +95,7 @@ ~HWNDMessageHandler() override; - void Init(HWND parent, const gfx::Rect& bounds); + void Init(HWND parent, const gfx::Rect& bounds, bool headless_mode); void InitModalType(ui::ModalType modal_type); void Close(); @@ -811,11 +812,20 @@ // call HandleWindowMinimizedOrRestored() when we get a WM_ACTIVATE message. bool notify_restore_on_activate_ = false; - // These track headless window visibility and fullscreen states. In headless - // mode the platform window is never made visible or fullscreen, so we use - // these booleans to track the requested window state. - bool headless_window_visibility_state_ = false; - bool headless_window_fullscreen_state_ = false; + // This tracks headless window visibility, fullscreen and min/max states. In + // headless mode the platform window is never made visible or change its + // state, so this structure holds the requested state for reporting. + struct HeadlessModeWindow { + bool IsMinimized() const { return minmax_state == kMinimized; } + bool IsMaximized() const { return minmax_state == kMaximized; } + + bool visibility_state = false; + bool fullscreen_state = false; + enum { kNormal, kMinimized, kMaximized } minmax_state = kNormal; + }; + + // This is present iff the window has been created in headless mode. + absl::optional<HeadlessModeWindow> headless_mode_window_; // This is a map of the HMONITOR to full screeen window instance. It is safe // to keep a raw pointer to the HWNDMessageHandler instance as we track the
diff --git a/ui/views/win/hwnd_message_handler_delegate.h b/ui/views/win/hwnd_message_handler_delegate.h index caff85f..d8e0f1d 100644 --- a/ui/views/win/hwnd_message_handler_delegate.h +++ b/ui/views/win/hwnd_message_handler_delegate.h
@@ -78,9 +78,6 @@ // Returns true if the delegate represents a modal window. virtual bool IsModal() const = 0; - // Returns true if the delegate represents a headless window. - virtual bool IsHeadless() const = 0; - // Returns the show state that should be used for the application's first // window. virtual int GetInitialShowState() const = 0;